import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation import EosTokenCriteria, validate_stopping_criteria
from transformers.generation.utils import (
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
    GenerateNonBeamOutput,
    GenerateOutput,
    GenerationConfig,
)
from transformers.utils import ModelOutput


# Cache attribute names used by different architectures; the first match wins.
ALL_CACHE_NAMES = [
    "past_key_values",
    "cache_params",
    "state",
    "mems",
    "past_buckets_states",
]


class TSGenerationMixin(GenerationMixin):
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: bool = True,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        if inputs is None or inputs.ndim != 2:
            raise ValueError('`inputs` must be a tensor of shape [batch_size, seq_len]')
        if revin:
            # Reversible instance normalization (RevIN): standardize each series
            # with its own statistics so the backbone sees a comparable scale.
            means = inputs.mean(dim=-1, keepdim=True)
            stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev
        outputs = super().generate(
            inputs=inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,
        )
        if revin:
            # Undo the normalization so forecasts return to the original scale.
            # This assumes `outputs` is a sample tensor ([batch, num_samples,
            # horizon]); dict-style `GenerateOutput` objects are not rescaled.
            stdev = stdev.unsqueeze(1)
            means = means.unsqueeze(1)
            outputs = (outputs * stdev) + means
        return outputs

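    # Usage sketch (illustrative only; the model object and forecast length are
    # assumptions, not part of this module):
    #
    #     context = torch.randn(4, 672)                  # [batch, seq_len]
    #     preds = model.generate(context, max_new_tokens=96, revin=True)
    #
    # With `revin=True`, `preds` is rescaled back to each series' original
    # mean and standard deviation.
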
    def _sample(
        self,
        input_ids: torch.Tensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
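        # Sampling loop adapted for continuous time-series tokens: each pass
        # through the loop below appends a multi-step forecast to the context.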
        input_ids = input_ids.to(self.device)
        batch_size, cur_len = input_ids.shape

        # Initialize runtime defaults and stopping criteria.
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        if eos_token_id is not None:
            stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
        else:
            # Recover `eos_token_id` from the criteria (or the generation config)
            # so that generation does not run forever.
            eos_token_id = [
                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
            ]
            eos_token_id = eos_token_id[0] if eos_token_id else None
            if eos_token_id is None and self.generation_config.eos_token_id is not None:
                eos_token_id = self.generation_config.eos_token_id
                stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # Initialize attention / hidden-state / score tuples for dict returns.
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # For encoder-decoder models, pull encoder attentions and hidden states
        # out of the precomputed `encoder_outputs`.
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # Keep track of which sequences are already finished.
        if "inputs_embeds" in model_kwargs:
            cur_len = model_kwargs["inputs_embeds"].shape[1]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
        # Each model token covers `input_token_len` time points, so the mask
        # spans ceil(cur_len / input_token_len) tokens (e.g. 672 points with
        # input_token_len=96 -> 7 tokens; 700 points -> 8 tokens).
        true_seq_len = (cur_len + self.config.input_token_len - 1) // self.config.input_token_len
        model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:]
        max_length = stopping_criteria.max_length

        generate_results = None
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # Prepare model inputs for this decoding step.
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            input_length = input_ids.shape[1]

            # Forward pass: predict up to `max_length - input_length` new points.
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                max_output_length=max_length - input_length,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources when this peer is finished
            next_token_logits = outputs.logits

            # Pre-process the prediction with the logits processors.
            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions, and hidden states when required.
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_tokens_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # For time series, the "tokens" are continuous forecast values with
            # shape [batch, num_samples, horizon].
            next_tokens = next_tokens_scores

            # Finished sequences should have their next values set to the pad value.
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                keep = unfinished_sequences.view(-1, 1, 1)  # broadcast over samples and horizon
                next_tokens = next_tokens * keep + pad_token_id * (1 - keep)

            # Number of patch tokens produced in this step.
            horizon_length = next_tokens.shape[-1] // self.config.input_token_len

            past_key_values = model_kwargs.get("past_key_values")
            if generate_results is None:
                generate_results = next_tokens
            else:
                generate_results = torch.cat([generate_results, next_tokens], dim=-1)

            # Append the median over samples as the next context window.
            selected_tokens = torch.quantile(next_tokens.float(), q=0.5, dim=1).to(input_ids.dtype)
            input_ids = torch.cat([input_ids, selected_tokens], dim=-1)

            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                horizon_length=horizon_length,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

        # Trim the context if the last step overshot `max_length`.
        if input_ids.shape[-1] > max_length:
            input_ids = input_ids[:, :max_length]

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            # Trim the accumulated samples to the requested forecast length.
            return generate_results[:, :, :(max_length - cur_len)]
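    # Note: unlike text generation, the non-dict return above is a float tensor
    # of sampled trajectories (inferred here to be [batch, num_samples,
    # forecast_len]) rather than a LongTensor of token ids.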

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # Carry the model cache into the next step, normalizing legacy names.
        for possible_cache_name in ALL_CACHE_NAMES:
            if possible_cache_name in outputs:
                if possible_cache_name in ("past_buckets_states", "mems"):
                    cache_name = "past_key_values"
                else:
                    cache_name = possible_cache_name
                model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
                break

        # Extend token_type_ids by repeating the last value.
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat(
                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
            )

        if not is_encoder_decoder:
            # Grow the attention mask by `horizon_length` freshly generated tokens.
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
                )
        else:
            # Same, but for the decoder mask of encoder-decoder models.
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask,
                     decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], horizon_length))],
                    dim=-1,
                )

        # Advance the cache position by the number of tokens just generated.
        if model_kwargs.get("cache_position") is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length

        # Accumulate hidden states for multi-token prediction (MTP), if the
        # model exposes them.
        if getattr(outputs, "hidden_states_for_mtp", None) is not None:
            new_hs = outputs.hidden_states_for_mtp
            if model_kwargs.get("full_hidden_states") is not None:
                existing = model_kwargs["full_hidden_states"]
                model_kwargs["full_hidden_states"] = torch.cat([existing.to(new_hs.device), new_hs], dim=1)
            else:
                model_kwargs["full_hidden_states"] = new_hs

        return model_kwargs
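

if __name__ == "__main__":
    # Minimal self-check of the RevIN round-trip used by `generate` above.
    # This is a sketch: it exercises only the normalization math, not a model.
    x = torch.randn(2, 32) * 7.0 + 3.0  # synthetic series, [batch, seq_len]
    means = x.mean(dim=-1, keepdim=True)
    stdev = x.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
    normed = (x - means) / stdev
    restored = normed * stdev + means
    assert torch.allclose(restored, x, atol=1e-4), "RevIN round-trip failed"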