"""
This module contains tooling to compare weights and activations
across models. Example usage::

    import copy
    import torch
    import torch.ao.quantization.quantize_fx as quantize_fx
    import torch.ao.ns._numeric_suite_fx as ns

    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
    mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
    # We convert a copy because we need the original prepared model
    # to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
    mq = quantize_fx.convert_fx(copy.deepcopy(mp))

    #
    # Comparing weights
    #

    # extract weight pairs
    weight_comparison = ns.extract_weights('a', mp, 'b', mq)

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # weight_comparison contains the weights from `mp` and `mq` stored
    # in pairs, and can be used for further analysis.


    #
    # Comparing activations, with error propagation
    #

    # add loggers
    mp_ns, mq_ns = ns.add_loggers(
        'a', copy.deepcopy(mp),
        'b', copy.deepcopy(mq),
        ns.OutputLogger)

    # send an example datum to capture intermediate activations
    datum = torch.randn(1, 1, 1, 1)
    mp_ns(datum)
    mq_ns(datum)

    # extract intermediate activations
    act_comparison = ns.extract_logger_info(
        mp_ns, mq_ns, ns.OutputLogger, 'b')

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # act_comparison contains the activations from `mp_ns` and `mq_ns` stored
    # in pairs, and can be used for further analysis.

    #
    # Comparing activations, without error propagation
    #

    # create shadow model
    mp_shadows_mq = ns.add_shadow_loggers(
        'a', copy.deepcopy(mp),
        'b', copy.deepcopy(mq),
        ns.OutputLogger)

    # send an example datum to capture intermediate activations
    datum = torch.randn(1, 1, 1, 1)
    mp_shadows_mq(datum)

    # extract intermediate activations
    shadow_act_comparison = ns.extract_shadow_logger_info(
        mp_shadows_mq, ns.OutputLogger, 'b')

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # shadow_act_comparison contains the activations of models 'a' and 'b',
    # both logged inside the single model `mp_shadows_mq`, stored in pairs,
    # and can be used for further analysis.

"""
| |
|
| | import collections |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.ao.quantization.quantize_fx as quantize_fx |
| | from torch.fx import GraphModule |
| | from torch.fx.graph import Node |
| | from torch.ao.ns.fx.mappings import ( |
| | get_base_name_to_sets_of_related_ops, |
| | ) |
| | from torch.ao.ns.fx.graph_matcher import ( |
| | get_matching_subgraph_pairs, |
| | get_type_a_related_to_b, |
| | ) |
| |
|
| | from .fx.weight_utils import ( |
| | extract_weight_from_node, |
| | ) |
| |
|
| | from .fx.graph_passes import ( |
| | add_loggers_to_model, |
| | create_a_shadows_b, |
| | ) |
| |
|
| | from .fx.utils import ( |
| | rekey_logger_info_on_node_name_of_model, |
| | maybe_add_missing_fqns, |
| | get_target_type_str, |
| | ) |
| |
|
| | from .fx.ns_types import ( |
| | NSSingleResultValuesType, |
| | NSResultsType, |
| | NSNodeTargetType, |
| | ) |
| |
|
| | from typing import Dict, Tuple, Callable, List, Optional, Set |
| |
|
# Shape of an LSTM-style output as recorded by OutputLogger:
# (output, (state_0, state_1)).
RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class OutputLogger(nn.Module):
    """
    Base class for capturing intermediate values.

    Each instance records (detached copies of) the values passed through its
    ``forward`` and returns them unchanged. The constructor arguments are
    stored verbatim as attributes; ``_extract_logger_info_one_model`` later
    reads them to place the recorded values in the results structure.

    Attributes:
        stats: recorded plain tensors, in call order.
        stats_rnn: recorded ``(output, (state_0, state_1))`` tuples, in call
            order.
    """
    stats: List[torch.Tensor]
    stats_rnn: List[RNNReturnType]

    # NOTE: FX treats modules whose class sets ``_is_impure = True`` as having
    # side effects, so dead-code elimination keeps calls to this logger even
    # though its output is just its input.
    _is_impure = True

    def __init__(
        self,
        ref_node_name: str,
        prev_node_name: str,
        model_name: str,
        ref_name: str,
        prev_node_target_type: str,
        ref_node_target_type: str,
        results_type: str,
        index_within_arg: int,
        index_of_arg: int,
        fqn: Optional[str],
    ):
        """
        Args:
            ref_node_name: name of the reference node; reported as
                ``'ref_node_name'`` in extracted results.
            prev_node_name: name of the node feeding this logger; reported as
                ``'prev_node_name'``.
            model_name: name of the model this logger lives in; used as the
                per-model key in extracted results.
            ref_name: name of the matched layer; used as the top-level key in
                extracted results.
            prev_node_target_type: target type of ``prev_node_name``.
            ref_node_target_type: target type of ``ref_node_name``.
            results_type: kind of values recorded; used as the second-level
                key in extracted results.
            index_within_arg: position of the value within its argument
                (together with ``index_of_arg`` it orders and pairs entries).
            index_of_arg: position of the argument the value belongs to.
            fqn: fully qualified name, if known.
        """
        super().__init__()
        self.stats: List[torch.Tensor] = []
        self.stats_rnn: List[RNNReturnType] = []

        self.ref_node_name = ref_node_name
        self.prev_node_name = prev_node_name
        self.model_name = model_name
        self.ref_name = ref_name
        self.prev_node_target_type = prev_node_target_type
        self.ref_node_target_type = ref_node_target_type
        self.results_type = results_type
        self.index_within_arg = index_within_arg
        self.index_of_arg = index_of_arg
        self.fqn = fqn

    def forward(self, x):
        """
        Record a detached copy of ``x`` and return ``x`` unchanged.

        A plain tensor goes to ``stats``. A pair ``(out, (s0, s1))`` whose
        second element is itself a 2-element tuple/list goes to
        ``stats_rnn``. Anything else is passed through without being recorded.
        """
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        elif (
            isinstance(x, tuple)
            and len(x) == 2
            # Require a real container here: a tensor whose first dimension
            # is 2 (e.g. a 2-layer GRU hidden state) or ``None`` must not be
            # mistaken for an (h, c) state pair.
            and isinstance(x[1], (tuple, list))
            and len(x[1]) == 2
        ):
            new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
            self.stats_rnn.append(new_res)
        return x

    def __repr__(self):
        return (
            f"OutputLogger(ref_name={self.ref_name}, model_name={self.model_name},\n"
            f"prev_node_name={self.prev_node_name}, ref_node_name={self.ref_node_name},\n"
            f"ref_node_target_type={self.ref_node_target_type},\n"
            f"results_type={self.results_type}, index_within_arg={self.index_within_arg},\n"
            f"index_of_arg={self.index_of_arg}, fqn={self.fqn})"
        )
| |
|
| |
|
class NSTracer(quantize_fx.QuantizationTracer):
    """
    FX quantization tracer used by the numeric suite.

    Identical to the regular quantization tracer except that observer and
    fake-quantize modules are never traced into: they are always reported as
    leaf modules.
    """

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
        """
        Return True for observers and fake-quantize modules; otherwise defer
        to the parent tracer's leaf-module decision.
        """
        always_leaf = (
            torch.ao.quantization.ObserverBase,
            torch.ao.quantization.FakeQuantizeBase,
        )
        if isinstance(m, always_leaf):
            return True
        return super().is_leaf_module(m, module_qualified_name)
| |
|
| |
|
def _extract_weights_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument: List[Tuple[Node, str]],
    results: NSResultsType,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> None:
    """
    Extract weights of ``model`` at the given nodes into ``results`` (in place).

    For each ``(node, ref_name)`` pair whose extraction result is truthy,
    ``results[ref_name][WEIGHT][model_name]`` is set to a one-element list
    holding that result. Pairs with a falsy extraction result are skipped.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
    weight_results_type = NSSingleResultValuesType.WEIGHT.value
    for node, ref_name in nodes_and_names_to_instrument:
        weight_info = extract_weight_from_node(
            node, model, op_to_type_to_weight_extraction_fn)
        if not weight_info:
            # nothing was extracted for this node; leave results untouched
            continue
        if ref_name not in results:
            results[ref_name] = {weight_results_type: {}}
        results[ref_name][weight_results_type][model_name] = [weight_info]
| |
|
| |
|
def _extract_weights_impl(
    model_name_a: str,
    gm_a: GraphModule,
    model_name_b: str,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    """
    Match subgraphs of ``gm_a`` and ``gm_b`` and extract the weights of each
    matched pair from both graph modules.

    Weights are taken from each subgraph's base op node. Missing fqns are
    filled in, and the results are re-keyed on the layer names of
    ``model_name_b``.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)

    nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = [
        (subgraph_a.base_op_node, match_name)
        for match_name, (subgraph_a, _) in matched_subgraph_pairs.items()
    ]
    nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = [
        (subgraph_b.base_op_node, match_name)
        for match_name, (_, subgraph_b) in matched_subgraph_pairs.items()
    ]

    results: NSResultsType = {}
    per_model_inputs = (
        (model_name_a, gm_a, nodes_and_names_to_instrument_a),
        (model_name_b, gm_b, nodes_and_names_to_instrument_b),
    )
    for name, gm, nodes_and_names in per_model_inputs:
        _extract_weights_one_model(
            name, gm, nodes_and_names, results,
            op_to_type_to_weight_extraction_fn)

    maybe_add_missing_fqns(results)

    return rekey_logger_info_on_node_name_of_model(results, model_name_b)
| |
|
| |
|
def extract_weights(
    model_name_a: str,
    model_a: nn.Module,
    model_name_b: str,
    model_b: nn.Module,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
    """
    Extract weights from model A and model B, and return a comparison.

    Both models are traced with ``NSTracer`` (observers and fake-quantize
    modules stay leaves). Any ``_node_name_to_scope`` attribute is carried
    over to the traced graph module.

    Args:
        model_name_a: string name of model A to use in results
        model_a: model A
        model_name_b: string name of model B to use in results
        model_b: model B
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change
        op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
            from a type, subject to change

    Return:
        NSResultsType, containing the weight comparisons
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    # NOTE(review): the return value of this call is not used anywhere below;
    # the call is kept so behavior is unchanged. Candidate for removal.
    get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    # Tracing skips no modules by name or by class.
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    def _trace_to_graph_module(model: nn.Module) -> GraphModule:
        tracer = NSTracer(skipped_module_names, skipped_module_classes)
        gm = GraphModule(model, tracer.trace(model))
        if hasattr(model, '_node_name_to_scope'):
            gm._node_name_to_scope = model._node_name_to_scope
        return gm

    gm_a = _trace_to_graph_module(model_a)
    gm_b = _trace_to_graph_module(model_b)
    return _extract_weights_impl(
        model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map, op_to_type_to_weight_extraction_fn)
| |
|
| |
|
def _add_loggers_one_model(
    model_name: str,
    model: GraphModule,
    nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
    nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
    logger_cls: Callable,
) -> nn.Module:
    """
    Insert loggers of type ``logger_cls`` into ``model``.

    The ``(node, ref_name, ref_node_type)`` triples are turned into
    ``node -> (ref_name, ref_node_type)`` maps for input and output logging
    (a later triple for the same node wins), and handed to
    ``add_loggers_to_model``, whose result is returned.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")

    node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {
        node: (ref_name, ref_node_type)
        for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs
    }
    node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {
        node: (ref_name, ref_node_type)
        for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs
    }

    return add_loggers_to_model(
        model, node_to_instrument_inputs_to_ref_name,
        node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
| |
|
| |
|
def _add_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Match subgraphs of ``gm_a`` and ``gm_b`` and instrument both with loggers.

    Every matched subgraph gets an output logger on its end node. When
    ``should_log_inputs`` is true it also gets an input logger on its start
    node. Loggers are labelled with the match name and the target type of the
    subgraph's base op node.

    Returns:
        ``(instrumented_a, instrumented_b)``.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
    matched_subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b,
        base_name_to_sets_of_related_ops, unmatchable_types_map)

    inputs_a = []
    inputs_b = []
    outputs_a = []
    outputs_b = []
    for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
        target_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        target_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)

        if should_log_inputs:
            inputs_a.append((subgraph_a.start_node, match_name, target_type_a))
            inputs_b.append((subgraph_b.start_node, match_name, target_type_b))

        outputs_a.append((subgraph_a.end_node, match_name, target_type_a))
        outputs_b.append((subgraph_b.end_node, match_name, target_type_b))

    instrumented_a = _add_loggers_one_model(
        name_a, gm_a, inputs_a, outputs_a, logger_cls)
    instrumented_b = _add_loggers_one_model(
        name_b, gm_b, inputs_b, outputs_b, logger_cls)
    return (instrumented_a, instrumented_b)
| |
|
| |
|
def add_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs : bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Tuple[nn.Module, nn.Module]:
    """
    Instrument model A and model B with loggers.

    Both models are traced with ``NSTracer``, and any ``_node_name_to_scope``
    attribute is carried over to the traced graph module. Matched subgraphs
    then receive loggers of type ``logger_cls``.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to also log the inputs of each matched
            subgraph (outputs are always logged)
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        Returns a tuple of (model_a_with_loggers, model_b_with_loggers). The
        instrumented models are built from ``model_a``/``model_b`` and may share
        submodules with them, so pass copies (as the module-level example does)
        if the originals must not be affected.
    """

    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")

    # Tracing skips no modules by name or by class.
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    def _trace_to_graph_module(model: nn.Module) -> GraphModule:
        tracer = NSTracer(skipped_module_names, skipped_module_classes)
        gm = GraphModule(model, tracer.trace(model))
        if hasattr(model, '_node_name_to_scope'):
            gm._node_name_to_scope = model._node_name_to_scope
        return gm

    gm_a = _trace_to_graph_module(model_a)
    gm_b = _trace_to_graph_module(model_b)
    return _add_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        unmatchable_types_map=unmatchable_types_map)
| |
|
| |
|
def _extract_logger_info_one_model(
    model: nn.Module,
    results: NSResultsType,
    logger_cls: Callable,
) -> None:
    """
    Collect what every logger inside ``model`` recorded into ``results`` (in place).

    A submodule counts as a logger if it is an instance of ``logger_cls``, or
    a TorchScript module whose original class name is ``'OutputLogger'``.
    Each logger contributes one entry to
    ``results[ref_name][results_type][model_name]``. Each such list is kept
    sorted by ``(index_of_arg, index_within_arg)``.

    The entry's ``'values'`` is the logger's ``stats_rnn`` if that is non-empty,
    otherwise its ``stats``.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
    for _, mod in model.named_modules():
        # Scripted loggers are no longer instances of ``logger_cls``, so they
        # are recognized by their original class name instead.
        is_logger = (
            isinstance(mod, logger_cls)
            or (
                isinstance(mod, torch.jit.RecursiveScriptModule)
                and mod.original_name == 'OutputLogger'
            )
        )
        if not is_logger:
            continue

        key = mod.ref_name
        if key not in results:
            results[key] = {}
        # NOTE(review): ``results[key]`` is keyed by results type, not by model
        # name. This check therefore only fires when a model name collides with
        # a results-type string, and it does not stop a model from being logged
        # twice. Confirm the intended invariant before tightening it.
        assert mod.model_name not in results[key], \
            f"{mod.model_name} is already present in results"
        if mod.results_type not in results[key]:
            results[key][mod.results_type] = {}
        if mod.model_name not in results[key][mod.results_type]:
            results[key][mod.results_type][mod.model_name] = []

        stats_to_use = mod.stats
        if len(mod.stats_rnn) > 0:
            stats_to_use = mod.stats_rnn

        model_results = results[key][mod.results_type][mod.model_name]
        model_results.append({
            'type': mod.results_type,
            'values': stats_to_use,
            'ref_node_name': mod.ref_node_name,
            'ref_node_target_type': mod.ref_node_target_type,
            'prev_node_name': mod.prev_node_name,
            'prev_node_target_type': mod.prev_node_target_type,
            'index_within_arg': mod.index_within_arg,
            'index_of_arg': mod.index_of_arg,
            'fqn': mod.fqn,
        })
        # Sort numerically on a tuple. A string key such as "10:0" would sort
        # before "2:0" once an index reaches two digits.
        model_results.sort(
            key=lambda res: (res['index_of_arg'], res['index_within_arg'])
        )
| |
|
| |
|
| | |
| | |
def extract_logger_info(
    model_a: nn.Module,
    model_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in `model_a` and `model_b`, and extract the logged
    information.

    Loggers from both models are collected into a single results structure.
    Missing fqns are filled in, and the results are re-keyed on the layer
    names of ``model_name_to_use_for_layer_names``.

    Args:
        model_a: model A
        model_b: model B
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
    collected: NSResultsType = {}
    _extract_logger_info_one_model(model_a, collected, logger_cls)
    _extract_logger_info_one_model(model_b, collected, logger_cls)
    maybe_add_missing_fqns(collected)
    return rekey_logger_info_on_node_name_of_model(
        collected, model_name_to_use_for_layer_names)
| |
|
| |
|
def _add_shadow_loggers_impl(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
    should_log_inputs: bool,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Match subgraphs of ``gm_a`` and ``gm_b`` and return the single model
    produced by ``create_a_shadows_b`` for those matches.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
    subgraph_pairs = get_matching_subgraph_pairs(
        gm_a, gm_b, base_name_to_sets_of_related_ops,
        unmatchable_types_map)
    return create_a_shadows_b(
        name_a, gm_a, name_b, gm_b, subgraph_pairs, logger_cls,
        should_log_inputs=should_log_inputs,
        node_type_to_io_type_map=node_type_to_io_type_map)
| |
|
| |
|
def add_shadow_loggers(
    name_a: str,
    model_a: nn.Module,
    name_b: str,
    model_b: nn.Module,
    logger_cls: Callable,
    should_log_inputs: bool = False,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> nn.Module:
    """
    Instrument model A and model B with shadow loggers.

    Both models are traced with ``NSTracer``, and any ``_node_name_to_scope``
    attribute is carried over to the traced graph module. The traced modules
    are then combined into a single model by ``create_a_shadows_b``.

    Args:
        name_a: string name of model A to use in results
        model_a: model A
        name_b: string name of model B to use in results
        model_b: model B
        logger_cls: class of Logger to use
        should_log_inputs: whether to log inputs
        base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
        node_type_to_io_type_map: optional override forwarded to
            ``create_a_shadows_b``, subject to change
        unmatchable_types_map: optional override of unmatchable types, subject to change

    Return:
        A single ``nn.Module`` in which model A shadows model B, containing
        the loggers for both.
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")

    # Tracing skips no modules by name or by class.
    skipped_module_names: List[str] = []
    skipped_module_classes: List[Callable] = []

    def _trace_to_graph_module(model: nn.Module) -> GraphModule:
        tracer = NSTracer(skipped_module_names, skipped_module_classes)
        gm = GraphModule(model, tracer.trace(model))
        if hasattr(model, '_node_name_to_scope'):
            gm._node_name_to_scope = model._node_name_to_scope
        return gm

    gm_a = _trace_to_graph_module(model_a)
    gm_b = _trace_to_graph_module(model_b)
    return _add_shadow_loggers_impl(
        name_a, gm_a, name_b, gm_b, logger_cls,
        should_log_inputs=should_log_inputs,
        base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
        node_type_to_io_type_map=node_type_to_io_type_map,
        unmatchable_types_map=unmatchable_types_map)
| |
|
| |
|
def extract_shadow_logger_info(
    model_a_shadows_b: nn.Module,
    logger_cls: Callable,
    model_name_to_use_for_layer_names: str,
) -> NSResultsType:
    """
    Traverse all loggers in a shadow model, and extract the logged
    information.

    Logged data is gathered into a ``defaultdict(dict)``. Missing fqns are
    filled in, the results are re-keyed on the layer names of
    ``model_name_to_use_for_layer_names``, and the outcome is returned as a
    plain ``dict``.

    Args:
        model_a_shadows_b: shadow model
        logger_cls: class of Logger to use
        model_name_to_use_for_layer_names: string name of model to use for
          layer names in the output

    Return:
        NSResultsType, containing the logged comparisons
    """
    torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
    collected: NSResultsType = collections.defaultdict(dict)
    _extract_logger_info_one_model(model_a_shadows_b, collected, logger_cls)
    maybe_add_missing_fqns(collected)
    rekeyed = rekey_logger_info_on_node_name_of_model(
        collected, model_name_to_use_for_layer_names)
    return dict(rekeyed)
| |
|
| |
|
def extend_logger_results_with_comparison(
    results: NSResultsType,
    model_name_1: str,
    model_name_2: str,
    comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    comparison_name: str,
) -> None:
    """
    Compares the logged values from `model_name_2` against the corresponding
    values in `model_name_1`, using `comparison_fn`. Records the result
    in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.

    For every layer and results type, each entry of `model_name_2` is paired
    with the first entry of `model_name_1` that has the same `index_of_arg`
    and `index_within_arg`. `comparison_fn(value_1, value_2)` is then applied
    to their `values` pairwise. As with `zip`, surplus values in the longer
    list are ignored.

    Args:
        results: the result data structure from `extract_logger_info` or
          `extract_shadow_logger_info`.
        model_name_1: string name of model 1
        model_name_2: string name of model 2
        comparison_fn: function to compare two Tensors
        comparison_name: key under which the list of comparison results is
          stored in each of `model_name_2`'s entries

    Raises:
        AssertionError: if either model name is missing for some layer and
          results type, or if an entry of `model_name_2` has no matching entry
          in `model_name_1`.
    """
    for results_type_to_results in results.values():
        for model_name_to_results in results_type_to_results.values():
            assert model_name_1 in model_name_to_results, \
                f"{model_name_1} not found in results"
            assert model_name_2 in model_name_to_results, \
                f"{model_name_2} not found in results"

            results_1 = model_name_to_results[model_name_1]
            results_2 = model_name_to_results[model_name_2]

            for result_2 in results_2:
                index_within_arg_2 = result_2['index_within_arg']
                index_of_arg_2 = result_2['index_of_arg']
                # first entry of model 1 logged at the same argument position
                result_1 = next(
                    (
                        cur_result_1 for cur_result_1 in results_1
                        if (
                            cur_result_1['index_within_arg'] == index_within_arg_2
                            and cur_result_1['index_of_arg'] == index_of_arg_2
                        )
                    ),
                    None,
                )
                assert result_1 is not None, (
                    f"no {model_name_1} result matches {model_name_2} result with "
                    f"index_of_arg={index_of_arg_2}, "
                    f"index_within_arg={index_within_arg_2}"
                )

                result_2[comparison_name] = [
                    comparison_fn(value_1, value_2)
                    for value_1, value_2 in zip(result_1['values'], result_2['values'])
                ]
| |
|