import os
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import BertTokenizer
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerationConfig, GenerateOutput
from transformers.utils import ModelOutput


class TSGenerationMixin(GenerationMixin):
    # Tokenizer for the optional text prompt; loaded once at class definition
    # from the local `bert_config` directory next to this file.
    tokenizer = BertTokenizer.from_pretrained(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config'),
        local_files_only=True,
    )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        text_inputs=None,
        text_input_ids: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        text_token_type_ids: Optional[torch.Tensor] = None,
        vision_inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: Optional[bool] = True,
        num_samples: Optional[int] = 1,
        max_output_length: Optional[int] = 96,
        inference_token_len: Optional[int] = None,
        max_text_token_length: Optional[int] = 125,
        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
        if inputs is None or len(inputs.shape) != 2:
            raise ValueError('Input shape must be: [batch_size, seq_len]')

        # RevIN-style instance normalization: standardize each series with its
        # own mean/std, then de-normalize the predictions at the end.
        if revin:
            means = inputs.mean(dim=-1, keepdim=True)
            stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev

        # Raw text prompts take precedence over pre-tokenized text inputs.
        if text_inputs is not None:
            tokenized_text = self._tokenize(text_inputs, max_length=max_text_token_length)
            text_input_ids = tokenized_text['input_ids'].squeeze(0)
            text_attention_mask = tokenized_text['attention_mask'].squeeze(0)
            text_token_type_ids = tokenized_text.get(
                'token_type_ids', torch.zeros_like(text_input_ids)
            ).squeeze(0)

        model_inputs = self.prepare_inputs_for_generation(
            inputs,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            text_token_type_ids=text_token_type_ids,
            vision_inputs=vision_inputs,
            generation_config=generation_config,
            max_output_length=max_output_length,
            inference_token_len=inference_token_len,
            **kwargs,
        )

        # Pass revin=False: normalization was already applied above.
        outputs = self(**model_inputs, return_dict=True, revin=False, num_samples=num_samples)
        predictions = outputs.logits

        if revin:
            # Broadcast the per-series statistics across the sample dimension
            # before de-normalizing: [batch, 1] -> [batch, num_samples, 1].
            stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
            means = means.unsqueeze(1).repeat(1, num_samples, 1)
            predictions = (predictions * stdev) + means

        return predictions

    def prepare_inputs_for_generation(
        self,
        inputs: torch.Tensor,
        text_input_ids: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        text_token_type_ids: Optional[torch.Tensor] = None,
        vision_inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        max_output_length: Optional[int] = None,
        inference_token_len: Optional[int] = None,
        **kwargs,
    ):
        return {
            "input_ids": inputs,
            "text_input_ids": text_input_ids,
            "text_attention_mask": text_attention_mask,
            "text_token_type_ids": text_token_type_ids,
            "vision_ids": vision_inputs,
            "max_output_length": max_output_length,
            "inference_token_len": inference_token_len,
            **kwargs,
        }

    def _tokenize(self, texts, max_length):
        return self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # The forecast is produced in a single forward pass, so there is no
        # autoregressive state to carry between generation steps.
        return model_kwargs
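

# --- Hypothetical usage sketch (illustration only, not part of the module) ---
# TSGenerationMixin is meant to be mixed into a model whose forward() accepts
# the keys returned by prepare_inputs_for_generation and returns a ModelOutput
# whose `.logits` holds samples of shape [batch, num_samples, horizon]. The
# class name `TSForecastModel`, the checkpoint path, and the tensor shapes
# below are assumptions for illustration.
#
#     model = TSForecastModel.from_pretrained("path/to/checkpoint")
#     history = torch.randn(4, 512)  # [batch_size, seq_len]
#     preds = model.generate(
#         history,
#         text_inputs=["demand spike expected next week"] * 4,
#         num_samples=20,
#         max_output_length=96,
#     )
#     # preds: [batch_size, num_samples, max_output_length], de-normalized
#     # back to the original scale via the RevIN statistics.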