# Aurora / ts_generation_mixin.py
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
from transformers import BertTokenizer, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateOutput, GenerationConfig
from transformers.utils import ModelOutput

if TYPE_CHECKING:
    # Imported only for the forward-referenced type hints below.
    from transformers import PreTrainedModel
    from transformers.generation.streamers import BaseStreamer


class TSGenerationMixin(GenerationMixin):
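    """Mixin that adapts `GenerationMixin.generate` to time-series forecasting.

    Instead of autoregressive token decoding, `generate` optionally applies
    RevIN normalization to the input series, tokenizes auxiliary text with a
    bundled BERT tokenizer, and runs a single forward pass that returns
    `num_samples` forecast trajectories per series.
    """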
    # Shared text tokenizer, loaded once from the local `bert_config`
    # directory that ships alongside this file.
    tokenizer = BertTokenizer.from_pretrained(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "bert_config"),
        local_files_only=True,
    )

@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
        text_inputs: Optional[Union[str, List[str]]] = None,
text_input_ids: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
text_token_type_ids: Optional[torch.Tensor] = None,
vision_inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
revin: Optional[bool] = True,
num_samples: Optional[int] = 1,
max_output_length: Optional[int] = 96,
inference_token_len: Optional[int] = None,
max_text_token_length: Optional[int] = 125,
**kwargs,
) -> Union[GenerateOutput, torch.Tensor]:
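        """Forecast `max_output_length` future values for each input series.

        `inputs` must be a 2-D tensor of shape [batch_size, seq_len]. The
        standard `GenerationMixin` arguments (`logits_processor`,
        `stopping_criteria`, `assistant_model`, `streamer`, ...) are accepted
        for interface compatibility but are not used by this single-pass
        implementation. The returned predictions carry a sample dimension,
        [batch_size, num_samples, horizon], and are denormalized when
        `revin=True`.
        """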
        if inputs is None or inputs.dim() != 2:
            raise ValueError("`inputs` must be a tensor of shape [batch_size, seq_len]")
        if revin:
            # RevIN-style normalization: standardize each series with its own
            # mean/std; the statistics are re-applied to the predictions below.
            means = inputs.mean(dim=-1, keepdim=True)
            stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev
        if text_inputs is not None:
            # Tokenize raw text on the fly; this overrides any text tensors
            # passed in directly.
            tokenized_text = self._tokenize(text_inputs, max_length=max_text_token_length)
            # BertTokenizer emits token_type_ids; keep a zeros fallback (built
            # before squeezing, so shapes stay consistent) for tokenizers that
            # do not.
            text_token_type_ids = tokenized_text.get(
                "token_type_ids", torch.zeros_like(tokenized_text["input_ids"])
            ).squeeze(0)
            text_input_ids = tokenized_text["input_ids"].squeeze(0)
            text_attention_mask = tokenized_text["attention_mask"].squeeze(0)
model_inputs = self.prepare_inputs_for_generation(
inputs,
text_input_ids=text_input_ids,
text_attention_mask=text_attention_mask,
text_token_type_ids=text_token_type_ids,
vision_inputs=vision_inputs,
generation_config=generation_config,
max_output_length=max_output_length,
inference_token_len=inference_token_len,
**kwargs
)
        # Single forward pass (no autoregressive loop). The model's internal
        # RevIN is disabled because normalization is handled in this method.
        outputs = self(**model_inputs, return_dict=True, revin=False, num_samples=num_samples)
        predictions = outputs.logits
        if revin:
            # Denormalize: expand the per-series statistics from [batch, 1] to
            # [batch, num_samples, 1] so they broadcast over the horizon.
            stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
            means = means.unsqueeze(1).repeat(1, num_samples, 1)
            predictions = (predictions * stdev) + means
        return predictions

def prepare_inputs_for_generation(
self,
inputs: torch.Tensor,
text_input_ids: Optional[torch.Tensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
text_token_type_ids: Optional[torch.Tensor] = None,
vision_inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
max_output_length: Optional[int] = None,
inference_token_len: Optional[int] = None,
**kwargs
):
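        """Assemble the keyword arguments for the model's forward pass.

        Note the renames: the raw series is passed on as `input_ids` and the
        vision features as `vision_ids`.
        """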
return {
"input_ids": inputs,
"text_input_ids": text_input_ids,
"text_attention_mask": text_attention_mask,
"text_token_type_ids": text_token_type_ids,
"vision_ids": vision_inputs,
"max_output_length": max_output_length,
"inference_token_len": inference_token_len,
**kwargs
}
    def _tokenize(self, texts: Union[str, List[str]], max_length: int):
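        """Tokenize `texts`, padding/truncating to exactly `max_length` tokens."""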
return self.tokenizer(
texts,
padding='max_length',
truncation=True,
max_length=max_length,
return_tensors="pt"
)
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
horizon_length: int = 1,
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
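        # This mixin generates in a single pass, so there is no per-step state
        # to carry forward; the kwargs are returned unchanged.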
return model_kwargs
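

# Minimal usage sketch (illustrative, not part of the original file):
# `AuroraForPrediction` is a hypothetical name for any model class that mixes
# in TSGenerationMixin and accepts the keyword arguments built above.
#
#     model = AuroraForPrediction.from_pretrained("path/to/checkpoint")
#     series = torch.randn(4, 512)  # [batch_size, seq_len]
#     preds = model.generate(
#         series,
#         text_inputs=["hourly electricity load"] * 4,
#         num_samples=20,
#         max_output_length=96,
#     )
#     # preds: [batch_size, num_samples, max_output_length]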