| """ |
| Misc Validators |
| ================= |
| Validators ensure compatibility between search methods, transformations, constraints, and goal functions. |
| |
| """ |
| import re |
|
|
| import textattack |
| from textattack.goal_functions import ( |
| InputReduction, |
| MinimizeBleu, |
| NonOverlappingOutput, |
| TargetedClassification, |
| UntargetedClassification, |
| ) |
|
|
| from . import logger |
|
|
| |
# Maps each group of goal function classes to the regular expressions that
# identify compatible models. Every pattern is applied with ``re.match`` to a
# model class's dotted path, ``<module>.<class name>``. The ``.`` separators
# are escaped so they only match a literal dot; the trailing ``.*`` in the
# helper patterns is an intentional wildcard for anything inside that module.
MODELS_BY_GOAL_FUNCTIONS = {
    (TargetedClassification, UntargetedClassification, InputReduction): [
        r"^textattack\.models\.helpers\.lstm_for_classification.*",
        r"^textattack\.models\.helpers\.word_cnn_for_classification.*",
        r"^transformers\.modeling_\w*\.\w*ForSequenceClassification$",
    ],
    (
        NonOverlappingOutput,
        MinimizeBleu,
    ): [
        r"^textattack\.models\.helpers\.t5_for_text_to_text.*",
    ],
}
|
|
| |
| |
| |
# Flattened view of ``MODELS_BY_GOAL_FUNCTIONS``: one entry per goal function
# class, each pointing at its group's (shared) pattern list. A comprehension
# is used so that no loop variables are left behind in the module namespace.
MODELS_BY_GOAL_FUNCTION = {
    goal_function: model_patterns
    for goal_function_group, model_patterns in MODELS_BY_GOAL_FUNCTIONS.items()
    for goal_function in goal_function_group
}
|
|
|
|
def validate_model_goal_function_compatibility(goal_function_class, model_class):
    """Log whether ``model_class`` is task-compatible with
    ``goal_function_class``.

    For example, a text-generative model like one intended for
    translation or summarization would not be compatible with a goal
    function that requires probability scores, like
    ``UntargetedClassification``.

    The check is advisory only: its outcome is reported through ``logger``
    and the function never raises on a mismatch.

    Args:
        goal_function_class: Goal function class looked up in
            ``MODELS_BY_GOAL_FUNCTION``.
        model_class: Model class; its ``__module__`` and ``__name__`` are
            joined into the dotted path that the patterns are matched against.

    Returns:
        None
    """
    # ``warning`` is used instead of the deprecated ``warn`` alias throughout
    # (assumes ``logger`` is a standard ``logging.Logger`` -- TODO confirm).
    try:
        matching_model_globs = MODELS_BY_GOAL_FUNCTION[goal_function_class]
    except KeyError:
        matching_model_globs = []
        logger.warning(f"No entry found for goal function {goal_function_class}.")

    model_module_path = ".".join((model_class.__module__, model_class.__name__))

    # Case 1: the model matches a pattern registered for this goal function.
    if any(re.match(glob, model_module_path) for glob in matching_model_globs):
        logger.info(
            f"Goal function {goal_function_class} compatible with model {model_class.__name__}."
        )
        return

    # Case 2: the model is known, but only under some goal function group.
    # Report the first group whose patterns match so the user can see which
    # goal functions the model is registered for.
    for goal_functions, globs in MODELS_BY_GOAL_FUNCTIONS.items():
        for glob in globs:
            if re.match(glob, model_module_path):
                logger.warning(
                    f"Unknown if model {model_class.__name__} compatible with provided goal function {goal_function_class}."
                    f" Found match with other goal functions: {goal_functions}."
                )
                return

    # Case 3: the model matches no registered pattern at all.
    logger.warning(
        f"Unknown if model of class {model_class} compatible with goal function {goal_function_class}."
    )
|
|
|
|
def validate_model_gradient_word_swap_compatibility(model):
    """Check that ``model`` can be used with ``GradientBasedWordSwap``.

    We can only take the gradient with respect to an individual word if
    the model uses a word-based tokenizer, so only instances of
    ``textattack.models.helpers.LSTMForClassification`` are accepted.

    Args:
        model: The model instance to check.

    Returns:
        ``True`` when ``model`` is an ``LSTMForClassification`` instance.

    Raises:
        ValueError: If ``model`` is any other kind of object.
    """
    supported_model_class = textattack.models.helpers.LSTMForClassification
    if not isinstance(model, supported_model_class):
        raise ValueError(f"Cannot perform GradientBasedWordSwap on model {model}.")
    return True
|
|
|
|
def transformation_consists_of(transformation, transformation_classes):
    """Tell whether ``transformation`` is built solely from the given classes.

    A ``CompositeTransformation`` qualifies when every one of its
    ``transformations`` qualifies (checked recursively); any other
    transformation qualifies when it is an instance of at least one class in
    ``transformation_classes``.

    Args:
        transformation: The transformation (possibly composite) to inspect.
        transformation_classes: Iterable of the allowed transformation classes.

    Returns:
        ``True`` if ``transformation`` consists only of allowed classes,
        ``False`` otherwise.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import -- TODO confirm).
    from textattack.transformations import CompositeTransformation

    if not isinstance(transformation, CompositeTransformation):
        # Leaf transformation: a single matching class is enough.
        return any(
            isinstance(transformation, allowed_class)
            for allowed_class in transformation_classes
        )

    # Composite: every member must itself qualify.
    return all(
        transformation_consists_of(member, transformation_classes)
        for member in transformation.transformations
    )
|
|
|
|
def transformation_consists_of_word_swaps(transformation):
    """Tell whether ``transformation`` is a word swap, or a composite made up
    exclusively of word swaps.

    Args:
        transformation: The transformation (possibly composite) to inspect.

    Returns:
        The result of ``transformation_consists_of`` for the classes
        ``WordSwap`` and ``WordSwapGradientBased``.
    """
    from textattack.transformations import WordSwap, WordSwapGradientBased

    word_swap_classes = [WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, word_swap_classes)
|
|
|
|
def transformation_consists_of_word_swaps_and_deletions(transformation):
    """Tell whether ``transformation`` is a word swap or word deletion, or a
    composite made up exclusively of word swaps and deletions.

    Args:
        transformation: The transformation (possibly composite) to inspect.

    Returns:
        The result of ``transformation_consists_of`` for the classes
        ``WordDeletion``, ``WordSwap`` and ``WordSwapGradientBased``.
    """
    from textattack.transformations import WordDeletion, WordSwap, WordSwapGradientBased

    swap_and_deletion_classes = [WordDeletion, WordSwap, WordSwapGradientBased]
    return transformation_consists_of(transformation, swap_and_deletion_classes)
|
|