import functools

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.utils import logging

from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive

logger = logging.get_logger(__name__)


class FlamingoLayer(nn.Module):
| """ |
| FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer. |
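
    Example (a minimal sketch; `xattn_block`, `decoder_block`, and the tensor
    shapes are hypothetical stand-ins for the real modules and inputs):

        layer = FlamingoLayer(xattn_block, decoder_block)
        layer.condition_vis_x(vis_x)  # visual features, e.g. (B, T_img, n, D_vis)
        layer.condition_media_locations(input_ids == media_token_id)
        layer.condition_use_cached_media(False)
        lang_x = layer(lang_x, attention_mask=attention_mask)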
| """ |
|
|
| def __init__( |
| self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False |
| ): |
| super().__init__() |
| self.gated_cross_attn_layer = gated_cross_attn_layer |
| self.decoder_layer = decoder_layer |
| self.vis_x = None |
| self.media_locations = None |
| if self.gated_cross_attn_layer is not None: |
| self.gated_cross_attn_layer._use_gradient_checkpointing = ( |
| gradient_checkpointing |
| ) |
| self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing |
| self._use_gradient_checkpointing = gradient_checkpointing |
| if self._use_gradient_checkpointing: |
| self.gradient_checkpointing_enable() |
|
|
| def is_conditioned(self) -> bool: |
| """Check whether the layer is conditioned.""" |
| return self.vis_x is not None and self.media_locations is not None |
|
|
| |
| def condition_vis_x(self, vis_x): |
| self.vis_x = vis_x |
|
|
| def condition_media_locations(self, media_locations): |
| self.media_locations = media_locations |
|
|
| def condition_use_cached_media(self, use_cached_media): |
| self.use_cached_media = use_cached_media |
|
|
| def forward( |
| self, |
| lang_x, |
| attention_mask=None, |
| **decoder_layer_kwargs, |
| ): |
| |
| if self.gated_cross_attn_layer is not None: |
| if self.vis_x is None: |
| raise ValueError("vis_x must be conditioned before forward pass") |
|
|
| if self.media_locations is None: |
| raise ValueError( |
| "media_locations must be conditioned before forward pass" |
| ) |
|
|
| lang_x = self.gated_cross_attn_layer( |
| lang_x, |
| self.vis_x, |
| media_locations=self.media_locations, |
| use_cached_media=self.use_cached_media, |
| ) |
|
|
| |
| if ( |
| self._use_gradient_checkpointing |
| and self.training |
| and isinstance(self.decoder_layer, MistralDecoderLayer) |
| ): |
| if ( |
| "use_cache" in decoder_layer_kwargs |
| and decoder_layer_kwargs["use_cache"] is True |
| ): |
| logger.warning_once( |
| "`use_cache=True` is incompatible with gradient checkpointing." |
| " Setting `use_cache=False`..." |
| ) |
| decoder_layer_kwargs["use_cache"] = False |
| |
| |
| |
| |
|
|
| |
| lang_x = self._gradient_checkpointing_func( |
| self.decoder_layer.__call__, |
| lang_x, |
| attention_mask, |
| decoder_layer_kwargs["position_ids"], |
| decoder_layer_kwargs["past_key_value"], |
| decoder_layer_kwargs["output_attentions"], |
| decoder_layer_kwargs["use_cache"], |
| ) |
| else: |
| lang_x = self.decoder_layer( |
| lang_x, attention_mask=attention_mask, **decoder_layer_kwargs |
| ) |
| return lang_x |
|
|
| def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): |
| """ |
| Activates gradient checkpointing for the current model. |
| |
| Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint |
| activations". |
| |
| We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of |
| the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 |
| |
| Args: |
| gradient_checkpointing_kwargs (dict, *optional*): |
| Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. |
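
        Example (non-reentrant checkpointing; `use_reentrant` is forwarded to
        `torch.utils.checkpoint.checkpoint`):

            layer.gradient_checkpointing_enable({"use_reentrant": False})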
| """ |
| if gradient_checkpointing_kwargs is None: |
| gradient_checkpointing_kwargs = {} |
|
|
| gradient_checkpointing_func = functools.partial( |
| checkpoint, **gradient_checkpointing_kwargs |
| ) |
|
|
| self._gradient_checkpointing_func = gradient_checkpointing_func |
|
|
| if getattr(self, "_hf_peft_config_loaded", False): |
| |
| |
| |
| |
| self.enable_input_require_grads() |
|
|
|
|
| class FlamingoLMMixin(nn.Module): |
| """ |
| Mixin to add cross-attention layers to a language model. |
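
    Example (a sketch of the intended composition; `MistralForCausalLM` as the
    base class is an assumption about the surrounding code, and in practice the
    mixin may instead be applied dynamically to an existing model instance):

        class FlamingoMistralForCausalLM(FlamingoLMMixin, MistralForCausalLM):
            pass

        model.set_decoder_layers_attr_name("model.layers")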
| """ |
|
|
| def set_decoder_layers_attr_name(self, decoder_layers_attr_name): |
| self.decoder_layers_attr_name = decoder_layers_attr_name |
|
|
| def _get_decoder_layers(self): |
| return getattr_recursive(self, self.decoder_layers_attr_name) |
|
|
| def _set_decoder_layers(self, value): |
| setattr_recursive(self, self.decoder_layers_attr_name, value) |
|
|
| def init_flamingo( |
| self, |
| media_token_id, |
| lang_hidden_size, |
| vis_hidden_size, |
| cross_attn_every_n_layers, |
| *, |
| enable_init_network_params=False, |
| initializer_range=0.02, |
| gradient_checkpointing=False, |
| ): |
| """ |
| Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. |
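
        Example (hypothetical sizes; `media_token_id` comes from the tokenizer's
        media token, and the hidden sizes depend on the chosen backbones):

            lang_model.init_flamingo(
                media_token_id=32000,
                lang_hidden_size=4096,
                vis_hidden_size=1024,
                cross_attn_every_n_layers=4,
            )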
| """ |
| self.old_decoder_blocks = self._get_decoder_layers() |
| self.gated_cross_attn_layers = nn.ModuleList( |
| [ |
| ( |
| GatedCrossAttentionBlock( |
| dim=lang_hidden_size, |
| dim_visual=vis_hidden_size, |
| ff_mult=4, |
| enable_init_network_params=enable_init_network_params, |
| initializer_range=initializer_range, |
| gradient_checkpointing=gradient_checkpointing, |
| ) |
| if (layer_idx + 1) % cross_attn_every_n_layers == 0 |
| else None |
| ) |
| for layer_idx, _ in enumerate(self._get_decoder_layers()) |
| ] |
| ) |
| self.init_flamingo_layers(gradient_checkpointing) |
| self.media_token_id = media_token_id |
| self.initialized_flamingo = True |
| self._use_cached_vision_x = False |
| self.gradient_checkpointing = gradient_checkpointing |
|
|
| def init_flamingo_layers(self, gradient_checkpointing): |
| """ |
| Re initializes the FlamingoLayers. |
| Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks |
| """ |
| self._set_decoder_layers( |
| nn.ModuleList( |
| [ |
| FlamingoLayer( |
| gated_cross_attn_layer, decoder_layer, gradient_checkpointing |
| ) |
| for gated_cross_attn_layer, decoder_layer in zip( |
| self.gated_cross_attn_layers, self.old_decoder_blocks |
| ) |
| ] |
| ) |
| ) |
|
|
| def forward(self, input_ids, attention_mask, **kwargs): |
| """Condition the Flamingo layers on the media locations before forward()""" |
        if not self.initialized_flamingo:
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo`"
                " first."
            )

        media_locations = input_ids == self.media_token_id

        # If media have already been cached and no media tokens appear in the
        # current input, assume all input tokens attend to the last cached media.
        # This matters for HF generate(), which calls forward() once per token.
        use_cached_media_locations = (
            self._use_cached_vision_x
            and self.is_conditioned()
            and not media_locations.any()
        )

        for layer in self._get_decoder_layers():
            if not use_cached_media_locations:
                layer.condition_media_locations(media_locations)
            layer.condition_use_cached_media(use_cached_media_locations)

        # Package the arguments as kwargs, since the argument order of the other
        # parent class's forward() is not known here.
        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask

        # Gradient checkpointing is incompatible with the KV cache.
        if self.gradient_checkpointing and isinstance(
            self.old_decoder_blocks[0], MistralDecoderLayer
        ):
            kwargs["use_cache"] = False
        return super().forward(**kwargs)

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_use_cached_media(None)