| import random |
| import pdb |
| from einops import rearrange |
| from typing import List, Optional, Tuple, Union |
| import os |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
| from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss |
|
|
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast |
| import transformers.models.opt.modeling_opt as modeling_opt |
| from transformers.models.opt.modeling_opt\ |
| import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig |
| from transformers import ViTModel |
|
|
| try: |
| from transformers.models.opt.modeling_opt import _prepare_4d_causal_attention_mask |
| except: |
| _prepare_4d_causal_attention_mask = None |
|
|
| from .utils import exists, freeze_all_layers_, unfreeze_all_layers_ |
| from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler |
| from .configuration_flamingo import FlamingoConfig |
|
|
|
|
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """
    Learned absolute position embeddings with a fixed index offset of 2.

    The underlying table has ``num_embeddings + 2`` rows, and every looked-up
    position index is shifted up by ``self.offset`` (2) before the lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # `offset` is needed to size the table, so it is assigned before the
        # nn.Embedding constructor runs.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """Embed positions derived from a 2-D ``attention_mask`` of shape [bsz x seqlen].

        A token's position is the number of attended tokens up to and including it,
        minus one; masked-out slots get position -1. The first
        ``past_key_values_length`` columns are dropped so only the current
        tokens' positions are embedded.
        """
        mask = attention_mask.long()

        # Running count of attended tokens, zeroed where the mask is 0, then made 0-based.
        running_count = mask.cumsum(dim=1)
        position_ids = (running_count * mask).long() - 1

        # Keep only the positions of tokens not already covered by the cache.
        position_ids = position_ids[:, past_key_values_length:]

        shifted_ids = position_ids + self.offset
        return super().forward(shifted_ids)
|
|
|
|
class OPTDecoder(modeling_opt.OPTDecoder):
    """
    Transformer decoder consisting of *config.num_hidden_layers* [`OPTDecoderLayer`] layers, interleaved with
    Flamingo [`GatedCrossAttentionBlock`]s that let text tokens attend to visual features.

    A gated cross-attention block is placed in front of decoder layer ``idx`` whenever
    ``idx % config.cross_attn_every == 0``; the other slots of ``gated_attn_layers`` hold ``None``.
    Visual features pass through ``perceiver_resampler`` first: a `PerceiverResampler`, or, when
    ``config.id_perceiver`` is set, an identity (``inp_dim is None``) or a bias-free linear projection.

    The vision encoder is not created here. Assign ``self.img_encoder`` before calling ``forward`` with
    ``pixel_values``.

    Args:
        config: OPTConfig that also carries the Flamingo fields read in ``__init__`` (``media_token_id``,
            ``id_perceiver``, ``inp_dim``, ``perceiver_depth``, ``perceiver_num_latents``,
            ``cross_attn_every``, ``only_attend_immediate_media``, ``finetune_LM``).
    """

    def __init__(self, config: OPTConfig):
        # Initialise via OPTPreTrainedModel directly rather than modeling_opt.OPTDecoder.__init__,
        # so every submodule below is built only by this class.
        OPTPreTrainedModel.__init__(self, config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.media_token_id = config.media_token_id

        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

        # Projections between the word-embedding width and the model width, only when they differ.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_in = None

        # The final LayerNorm exists only for pre-LN configs that did not set `_remove_final_layer_norm`.
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.final_layer_norm = None

        dim_head = config.hidden_size // config.num_attention_heads
        if not config.id_perceiver:
            self.perceiver_resampler = PerceiverResampler(
                dim=config.hidden_size,
                depth=config.perceiver_depth,
                dim_head=dim_head,
                heads=config.num_attention_heads,
                num_latents=config.perceiver_num_latents,
                inp_dim=config.inp_dim,
            )
        elif config.inp_dim is None:
            self.perceiver_resampler = nn.Identity()
        else:
            self.perceiver_resampler = nn.Linear(config.inp_dim, config.hidden_size, bias=False)

        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        # One slot per decoder layer; only every `cross_attn_every`-th slot holds a cross-attention block.
        self.gated_attn_layers = nn.ModuleList([
            GatedCrossAttentionBlock(
                dim=config.hidden_size,
                dim_head=dim_head,
                heads=config.num_attention_heads,
                only_attend_immediate_media=config.only_attend_immediate_media,
            ) if ind % config.cross_attn_every == 0 else None
            for ind in range(config.num_hidden_layers)
        ])

        # NOTE(review): this flag is never consulted in `forward`, so gradient checkpointing is not applied.
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

        if not config.finetune_LM:
            # Freeze the whole decoder, then re-enable training for the Flamingo additions only.
            freeze_all_layers_(self)
            unfreeze_all_layers_(self.perceiver_resampler)
            for cross_attn in self.gated_attn_layers:
                if exists(cross_attn):
                    unfreeze_all_layers_(cross_attn)

    def _encode_pixel_values(self, pixel_values, batch):
        """Run the attached (frozen) image encoder and return its last hidden state regrouped per sample.

        A 4-D ``pixel_values`` gets a singleton media axis inserted at dim 1. The first two axes are then
        flattened for the encoder and split back into ``(batch, num_media, ...)`` afterwards.
        """
        img_encoder = getattr(self, 'img_encoder', None)
        assert exists(img_encoder), 'img_encoder must be passed in for automatic image encoding'
        if pixel_values.ndim == 4:
            pixel_values = torch.unsqueeze(pixel_values, 1)
        pixel_values = rearrange(pixel_values, 'b t ... -> (b t) ...')

        with torch.no_grad():
            # Encoders that wrap a `vision_model` are called through it; otherwise call the encoder itself.
            vision_model = getattr(img_encoder, 'vision_model', None)
            encoder = vision_model if vision_model is not None else img_encoder
            image_outputs = encoder(
                pixel_values=pixel_values,
                output_hidden_states=True, return_dict=True)

        image_embeds = image_outputs['last_hidden_state']
        return rearrange(image_embeds, '(b t) ... -> b t ...', b=batch)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values=None,
        image_embeds=None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.
                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                When omitted, an all-ones mask covering the cached plus the current tokens is used.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` holding the cached keys and
                values of the self-attention blocks. They can be used to speed up sequential decoding. Only the
                decoder layers are cached; the gated cross-attention blocks are recomputed every call.
                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix. Not usable together with `pixel_values` /
                `image_embeds`, because media positions are read from `input_ids`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            pixel_values (`torch.FloatTensor`, *optional*):
                Raw images for the attached `img_encoder`. A 4-D tensor is treated as one image per sample (a media
                axis is inserted at dim 1); otherwise the first two axes are `(batch_size, num_media)`. Tokens equal
                to `config.media_token_id` in `input_ids` mark where media occur. Mutually exclusive with
                `image_embeds`.
            image_embeds (`torch.FloatTensor`, *optional*):
                Precomputed visual features, passed through `perceiver_resampler` before cross-attention.
                Mutually exclusive with `pixel_values`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Retrieve input_ids and inputs_embeds; `batch` must come from whichever one was given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # Flamingo mode: visual input is present and media positions are read off the token ids.
        media_locations = None
        if exists(pixel_values) or exists(image_embeds):
            assert not (exists(pixel_values) and exists(image_embeds))
            if input_ids is None:
                raise ValueError(
                    "input_ids are required when pixel_values or image_embeds are given, "
                    "to locate the media tokens"
                )
            media_locations = input_ids == self.media_token_id
            if exists(pixel_values):
                image_embeds = self._encode_pixel_values(pixel_values, batch)

        if exists(image_embeds):
            image_embeds = self.perceiver_resampler(image_embeds)

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if attention_mask is None:
            # The default mask must cover cached + current tokens. embed_positions slices off the first
            # `past_key_values_length` columns, and (presumably) the 4-D causal mask spans every key position.
            mask_seq_length = past_key_values_length + inputs_embeds.shape[1]
            attention_mask = torch.ones(
                (inputs_embeds.shape[0], mask_seq_length), dtype=torch.bool, device=inputs_embeds.device
            )
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        if _prepare_4d_causal_attention_mask is None:
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        if head_mask is not None and head_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # LayerDrop: while training, skip the whole layer (and its cross-attention) with probability
            # `layerdrop`. The RNG is only sampled during training, so inference leaves `random` untouched.
            if self.training:
                dropout_probability = random.uniform(0, 1)
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            flamingo_cross_attn = self.gated_attn_layers[idx]
            if exists(flamingo_cross_attn) and exists(image_embeds):
                hidden_states = flamingo_cross_attn(
                    hidden_states,
                    image_embeds,
                    media_locations=media_locations
                )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            # The cache entry follows the attention weights in the layer's output tuple when those are requested.
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
class OPTModel(modeling_opt.OPTModel):
    """OPT base model whose decoder is this module's Flamingo-aware `OPTDecoder`."""

    def __init__(self, config: OPTConfig):
        # Calls OPTPreTrainedModel.__init__ directly instead of modeling_opt.OPTModel.__init__,
        # presumably so that the stock OPT decoder is never constructed.
        OPTPreTrainedModel.__init__(self, config)
        flamingo_decoder = OPTDecoder(config)
        self.decoder = flamingo_decoder
        # Initialize weights and apply final processing
        self.post_init()
|
|
|
|
class OPTForCausalLM(modeling_opt.OPTForCausalLM):
    """Causal language-modelling head on top of this module's Flamingo-aware `OPTModel`."""

    # Don't report `lm_head.weight` as missing when loading checkpoints (presumably it is tied to the
    # input embeddings) — TODO confirm.
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        OPTPreTrainedModel.__init__(self, config)
        self.model = OPTModel(config)

        # Bias-free projection from the word-embedding width onto the vocabulary.
        self.lm_head = nn.Linear(
            in_features=config.word_embed_proj_dim,
            out_features=config.vocab_size,
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()
|
|
|
|
def set_default_if_nonexist(config, key, value):
    """Assign ``value`` to ``config.<key>`` when that attribute is absent or ``None``.

    An existing non-None value is left alone. ``config`` is mutated in place and
    also returned, for convenience.
    """
    current = getattr(config, key, None)
    if current is not None:
        return config
    setattr(config, key, value)
    return config
|
|
|
|
def setup_default_flamingo_configs(config):
    """Fill in the Flamingo hyper-parameters that ``config`` leaves missing or ``None``.

    Values already present on ``config`` are kept. The object is mutated in place and returned.
    """
    flamingo_defaults = {
        'perceiver_depth': 2,
        'perceiver_num_latents': 64,
        'cross_attn_every': 3,
        'only_attend_immediate_media': True,
        'media_token_id': 50265,
        # Presumably the hidden width of the facebook/dino-vitb16 encoder the Flamingo models load — confirm.
        'inp_dim': 768,
        'finetune_LM': True,
        'id_perceiver': False,
    }
    for key, default in flamingo_defaults.items():
        set_default_if_nonexist(config, key, default)
    return config
|
|
|
|
class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
    """Flamingo-conditioned OPT causal language model.

    Wraps this module's `OPTModel` (whose decoder holds the gated cross-attention blocks) with a
    bias-free `lm_head`. A frozen `ViTModel` loaded from "facebook/dino-vitb16" is attached as the
    decoder's image encoder at construction time.
    """
    _keys_to_ignore_on_load_missing = [
        r"lm_head.weight",
    ]
    config_class = FlamingoConfig

    def __init__(self, config):
        OPTPreTrainedModel.__init__(self, config)
        # Mutates `config` in place, so `self.config` also receives the Flamingo defaults.
        config = setup_default_flamingo_configs(config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        self.model.decoder.img_encoder = None
        self.loss_fct = CrossEntropyLoss()
        # NOTE: ViTModel.from_pretrained runs on every construction of this model.
        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
        self.setup_vis_encoder(dino_model)

    def setup_vis_encoder(self, img_encoder):
        """Attach ``img_encoder`` to the decoder and freeze all of its parameters."""
        self.model.decoder.img_encoder = img_encoder
        freeze_all_layers_(img_encoder)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.
                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` holding cached self-attention
                keys and values, which can be used to speed up sequential decoding.
                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
                Labels are moved to the logits' device before the loss is computed.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            *args, **kwargs:
                Forwarded unchanged to the decoder (e.g. `pixel_values=` or `image_embeds=`).
        Returns:
        Example (plain OPT usage of the parent class API):
        ```python
        >>> from transformers import GPT2Tokenizer, OPTForCausalLM
        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn).
        # Extra positional args are bound positionally before the keywords, exactly as Python
        # would do for the original `f(kw=..., *args)` spelling.
        outputs = self.model.decoder(
            *args,
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        logits = self.lm_head(outputs[0]).contiguous()

        loss = None
        if labels is not None:
            # Labels may live on a different device than the head (e.g. sharded models).
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class FlamingoForSequenceClassification(OPTPreTrainedModel):
    """Flamingo-conditioned OPT with a bias-free sequence classification / regression head.

    ``score`` is applied to every position's final hidden state, and one logit vector per sequence is kept.
    When ``input_ids`` are given and ``config.pad_token_id`` is set, that is the token just before the first
    pad token (or the last token if there is no pad). Otherwise it is the last position. A frozen `ViTModel`
    loaded from "facebook/dino-vitb16" is attached as the decoder's image encoder at construction time.
    """
    _keys_to_ignore_on_load_missing = [
        r"score.weight",
    ]

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        # Mutates `config` in place, so `self.config` also receives the Flamingo defaults.
        config = setup_default_flamingo_configs(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)

        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        self.model.decoder.img_encoder = None
        self.loss_fct = CrossEntropyLoss()
        # NOTE: ViTModel.from_pretrained runs on every construction of this model.
        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
        self.setup_vis_encoder(dino_model)

    def setup_vis_encoder(self, img_encoder):
        """Attach ``img_encoder`` to the decoder and freeze all of its parameters."""
        self.model.decoder.img_encoder = img_encoder
        freeze_all_layers_(img_encoder)

    def _pool_logits(self, logits, input_ids, inputs_embeds):
        """Select one logit vector per sequence (see the class docstring for which position)."""
        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

        if self.config.pad_token_id is None or input_ids is None:
            sequence_lengths = -1
        else:
            # Index of the first pad token minus one. With no pad token argmax is 0, giving -1,
            # which the modulo wraps to the final position.
            sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)

        return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

    def _compute_loss(self, pooled_logits, labels):
        """Infer ``config.problem_type`` if unset, then return the matching loss (None for unknown types)."""
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_labels == 1:
                return loss_fct(pooled_logits.squeeze(), labels.squeeze())
            return loss_fct(pooled_logits, labels)
        if self.config.problem_type == "single_label_classification":
            return self.loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
        if self.config.problem_type == "multi_label_classification":
            return BCEWithLogitsLoss()(pooled_logits, labels)
        return None

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args, **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Labels are moved to the
            logits' device before the loss is computed.
        *args, **kwargs:
            Forwarded unchanged to the decoder (e.g. `pixel_values=` or `image_embeds=`).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Extra positional args are bound positionally before the keywords, as in the `f(kw=..., *args)` form.
        outputs = self.model.decoder(
            *args,
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        hidden_states = outputs[0]
        logits = self.score(hidden_states)
        pooled_logits = self._pool_logits(logits, input_ids, inputs_embeds)

        loss = None
        if labels is not None:
            # Labels may live on a different device than the head (e.g. sharded models).
            loss = self._compute_loss(pooled_logits, labels.to(pooled_logits.device))

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module."""
        self.model.decoder.embed_tokens = value
|
|