Spaces:
Runtime error
Runtime error
from typing import Callable, Optional, Union

from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import ViTForImageClassification
| def _add_hooks( | |
| model: ViTForImageClassification, get_hook: callable | |
| ) -> list[RemovableHandle]: | |
| """Adds a list of hooks to the model according to the get_hook function provided. | |
| Args: | |
| model (ViTForImageClassification): the ViT instance to add hooks to | |
| get_hook (callable): a function that takes an index and returns a hook | |
| Returns: | |
| a list of RemovableHandle instances | |
| """ | |
| return ( | |
| [model.vit.embeddings.patch_embeddings.register_forward_hook(get_hook(0))] | |
| + [ | |
| layer.register_forward_pre_hook(get_hook(i + 1)) | |
| for i, layer in enumerate(model.vit.encoder.layer) | |
| ] | |
| + [ | |
| model.vit.encoder.layer[-1].register_forward_hook( | |
| get_hook(len(model.vit.encoder.layer) + 1) | |
| ) | |
| ] | |
| ) | |
def vit_getter(
    model: ViTForImageClassification, x: Tensor
) -> tuple[Tensor, list[Tensor]]:
    """Run a forward pass and return the logits plus the recorded hidden states.

    Hidden states are captured by temporary hooks registered through ``_add_hooks``.
    With L = len(model.vit.encoder.layer), each hook index records:
      * 0            -- the output of ``model.vit.embeddings.patch_embeddings``
      * i (1..L)     -- the first positional input to encoder layer ``i - 1``
      * L + 1        -- the output of the last encoder layer (first element if the
                        layer returns a tuple, otherwise the tensor itself)
    States are appended in the order the hooks fire. For a standard forward pass
    this is presumably index order 0..L+1; it is not enforced here.
    All hooks are removed before returning, even if the forward pass raises.

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model
    Returns:
        a tuple of the model's logits and the list of recorded hidden states
    """
    n_layers = len(model.vit.encoder.layer)
    hidden_states_: list[Tensor] = []

    def get_hook(i: int) -> Callable:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[Union[tuple, Tensor]] = None
        ) -> None:
            if i == 0:
                hidden_states_.append(outputs)
            elif 1 <= i <= n_layers:
                hidden_states_.append(inputs[0])
            elif i == n_layers + 1:
                # Depending on the transformers version, an encoder layer returns
                # either a tuple (hidden_states, ...) or a bare tensor; indexing a
                # bare tensor with [0] would silently keep only the first batch item.
                hidden_states_.append(
                    outputs[0] if isinstance(outputs, tuple) else outputs
                )

        return hook

    handles = _add_hooks(model, get_hook)
    try:
        logits = model(x).logits
    finally:
        for handle in handles:
            handle.remove()
    return logits, hidden_states_
def vit_setter(
    model: ViTForImageClassification, x: Tensor, hidden_states: list[Optional[Tensor]]
) -> tuple[Tensor, list[Tensor]]:
    """Overwrite selected hidden states, rerun the model, and return the new results.

    Uses the same hook indices as ``_add_hooks``, with L = len(model.vit.encoder.layer).
    For each index, a non-None ``hidden_states[i]`` replaces the corresponding value:
      * 0        -- replaces the output of ``patch_embeddings`` with
                    ``hidden_states[0][:, 1:]``. The first position along dim 1 is
                    dropped. NOTE(review): presumably callers pass index 0 with a
                    leading CLS-like token so every index shares one shape. A tensor
                    taken directly from ``vit_getter`` at index 0 would lose one
                    patch here; confirm against callers.
      * i (1..L) -- replaces the first positional input to encoder layer ``i - 1``
      * L + 1    -- replaces the (first) output of the last encoder layer
    A None entry leaves that state untouched and just records it. ``hidden_states``
    must have at least L + 2 entries; otherwise an IndexError is raised from inside
    the forward pass. All hooks are removed before returning, even on error.

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model
        hidden_states (list[Optional[Tensor]]): a list, with each element corresponding to
            a hidden state to set or None to calculate anew for that index
    Returns:
        a tuple of the model's logits and the (new) hidden states, recorded in the
        order the hooks fire
    """
    n_layers = len(model.vit.encoder.layer)
    hidden_states_: list[Tensor] = []

    def get_hook(i: int) -> Callable:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[Union[tuple, Tensor]] = None
        ) -> Optional[Union[tuple, Tensor]]:
            # Returning a non-None value from a forward hook replaces the module's
            # output; from a forward pre-hook it replaces the positional inputs.
            if i == 0:
                if hidden_states[i] is not None:
                    hidden_states_.append(hidden_states[i][:, 1:])
                    return hidden_states_[-1]
                hidden_states_.append(outputs)
            elif 1 <= i <= n_layers:
                if hidden_states[i] is not None:
                    hidden_states_.append(hidden_states[i])
                    return (hidden_states[i],) + inputs[1:]
                hidden_states_.append(inputs[0])
            elif i == n_layers + 1:
                # Depending on the transformers version, an encoder layer returns
                # either a tuple (hidden_states, ...) or a bare tensor; preserve
                # whichever form the model produced.
                is_tuple = isinstance(outputs, tuple)
                if hidden_states[i] is not None:
                    hidden_states_.append(hidden_states[i])
                    if is_tuple:
                        return (hidden_states[i],) + outputs[1:]
                    return hidden_states[i]
                hidden_states_.append(outputs[0] if is_tuple else outputs)
            return None

        return hook

    handles = _add_hooks(model, get_hook)
    try:
        logits = model(x).logits
    finally:
        for handle in handles:
            handle.remove()
    return logits, hidden_states_