# NOTE: removed copy/paste extraction artifacts ("Spaces:", "Runtime error") that were not part of the source.
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from typing import Union, List, Dict, Optional | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, GPTJForCausalLM | |
| from transformers.generation_logits_process import ( | |
| LogitsProcessorList, | |
| NoBadWordsLogitsProcessor, | |
| NoRepeatNGramLogitsProcessor, | |
| ) | |
| from utils import ( | |
| NEGATIVE_INF, HALF_NEGATIVE_INF, | |
| logits_to_entropy, mask_pad | |
| ) | |
| from clipcap import ClipCap | |
class Policy(nn.Module):
    """Captioning policy: a ClipCap (CLIP-prefix + GPT) language model plus
    sampling and scoring utilities.

    `sample` autoregressively decodes a response conditioned on image
    `features` (and an optional text prompt), returning per-token log-probs
    and masks suitable for policy-gradient training. `forward_pass` re-scores
    an existing query/response pair under the current model parameters.
    """

    def __init__(self, model_name, temperature, device, clipcap_path='', fix_gpt=False,
                 use_transformer_mapper: bool = False, use_ptuning_v2: bool = False,
                 prefix_length=10, clipcap_num_layers: int = 1,
                 label_path: str = '', model_weight: str = 'None', use_label_prefix: bool = False):
        """Build the underlying ClipCap model and its tokenizer.

        Args:
            model_name: HuggingFace model id, passed both to ClipCap and to
                AutoTokenizer.from_pretrained.
            temperature: default softmax temperature used by `sample` when the
                caller does not supply one.
            device: torch device the policy runs on by default.
            clipcap_path, fix_gpt, use_transformer_mapper, use_ptuning_v2,
            prefix_length, clipcap_num_layers, label_path, model_weight,
            use_label_prefix: forwarded verbatim to ClipCap — see that class
                for their semantics (not visible from this file).
        """
        super().__init__()
        self.device = device
        self.model = ClipCap(model_name, device,
                             model_path=clipcap_path, fix_gpt=fix_gpt,
                             prefix_length=prefix_length,
                             num_layers=clipcap_num_layers,
                             label_path=label_path, model_weight=model_weight,
                             use_transformer_mapper=use_transformer_mapper,
                             use_ptuning_v2=use_ptuning_v2,
                             use_label_prefix=use_label_prefix)
        # GPT-2-style vocabularies have no dedicated pad token; reuse
        # <|endoftext|> for padding and propagate its id into the model config
        # so generation utilities agree with the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, pad_token="<|endoftext|>")
        self.model.gpt.config.pad_token_id = self.tokenizer.pad_token_id
        self.temperature = temperature

    def get_processor(self, no_repeat_ngram_size: int = 3):
        """Build the LogitsProcessorList used during sampling.

        Adds n-gram repetition blocking when `no_repeat_ngram_size` > 0;
        otherwise returns an empty processor list.
        """
        logits_processor = LogitsProcessorList()
        if no_repeat_ngram_size > 0:
            logits_processor.append(NoRepeatNGramLogitsProcessor(ngram_size=no_repeat_ngram_size))
        # NOTE(review): pad-token blocking is deliberately left disabled here
        # (kept as a string literal, not executed).
        '''
        logits_processor.append(NoBadWordsLogitsProcessor([[self.tokenizer.pad_token_id]],
                                                          self.tokenizer.pad_token_id))
        '''
        return logits_processor

    def sample(self,
               input_ids: torch.Tensor = None,
               features: torch.Tensor = None,
               attention_mask: torch.Tensor = None,
               labels: Optional[torch.Tensor] = None,
               max_len: int = 20,
               sample: bool = True,
               top_k: int = None,
               top_p: float = None,
               temperature: float = None,
               no_repeat_ngram_size: int = 0,
               invalidate_eos: bool = True,
               device=None) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """Autoregressively decode up to `max_len` tokens per sequence.

        Args:
            input_ids: prompt token ids, shape (batch, prompt_len).
            features: image features forwarded to
                ClipCap.prepare_inputs_for_generation — assumed to be the CLIP
                prefix embedding input; exact shape is defined by ClipCap
                (TODO confirm against that class).
            attention_mask: mask over `input_ids`; its row sums locate the
                last real (non-pad) prompt position for each sequence.
            labels: optional extra conditioning forwarded to the model.
            max_len: maximum number of tokens to generate.
            sample: True → multinomial sampling from the warped distribution;
                False → greedy argmax.
            top_k / top_p / temperature: standard HF logits-warper settings
                (temperature falls back to `self.temperature`).
            no_repeat_ngram_size: n-gram blocking size (0 disables it).
            invalidate_eos: if True, EOS is masked out so generation always
                runs to `max_len`.
            device: overrides `self.device` when given.

        Returns:
            Dict with the prompt ('query/...') and generation ('response/...')
            ids, decoded text, masks, per-token log-probs, and per-step EOS
            probabilities. Log-probs are taken from the *unwarped* distribution
            (after the processor / EOS masking, before top-k/top-p).
        """
        if device is None:
            device = self.device
        if temperature is None:
            temperature = self.temperature

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        model_kwargs = {'attention_mask': attention_mask}
        batch_size, input_seq_len = input_ids.shape

        logits_processor = self.get_processor(no_repeat_ngram_size=no_repeat_ngram_size)
        # NOTE(review): private HF API — behavior may differ across
        # transformers versions; pinned usage assumed elsewhere in the project.
        logits_warper = self.model.gpt._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature, num_beams=1
        )

        # 1 while a sequence is still generating, 0 once it has emitted EOS.
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=device)
        # Per-step accumulators, grown by concatenation along dim 1.
        output_logprob = torch.zeros([batch_size, 0], device=device)
        eos_logprobs = torch.zeros([batch_size, 0], device=device)
        output_mask = torch.ones([batch_size, 0], dtype=torch.long, device=device)

        self.model.eval()
        with torch.no_grad():
            for step in range(max_len):
                # prepare model inputs
                model_inputs = self.model.prepare_inputs_for_generation(input_ids,
                                                                        features=features,
                                                                        labels=labels,
                                                                        **model_kwargs)
                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    device=device
                )

                # in the first decoding step, we want to use the 'real' last position for each sentence
                # (prompts may be right-padded, so position -1 is not the last
                # real token for every row).
                if step == 0:
                    last_non_masked_idx = torch.sum(attention_mask, dim=1) - 1
                    next_token_logits = outputs.logits[range(batch_size), last_non_masked_idx, :]
                else:
                    next_token_logits = outputs.logits[:, -1, :]

                # Use a -inf surrogate that matches the logits dtype (fp16 has
                # no representable float32 -inf substitute).
                negative_inf = HALF_NEGATIVE_INF if next_token_logits.dtype == torch.half else NEGATIVE_INF
                next_token_scores = logits_processor(input_ids, next_token_logits)
                if invalidate_eos:
                    next_token_scores[:, self.tokenizer.eos_token_id] = negative_inf  # no endoftext
                # Log-probs are taken BEFORE top-k/top-p warping so they
                # reflect the authentic (unwarped) sampling distribution.
                log_prob = F.log_softmax(next_token_scores, dim=-1)  # authentic sampling distribution
                next_token_scores = logits_warper(input_ids, next_token_scores)

                if sample:
                    # Temperature (higher temperature => more likely to sample low probability tokens)
                    probs = F.softmax(next_token_scores, dim=-1)
                    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                else:
                    # Greedy decoding
                    next_tokens = torch.argmax(next_token_scores, dim=-1)

                # finished sentences should have their next token be a padding token
                next_tokens = next_tokens * unfinished_sequences + self.tokenizer.pad_token_id * (1 - unfinished_sequences)

                # update output mask (1 = this step produced a real token)
                output_mask = torch.cat([output_mask, unfinished_sequences[:, None]], dim=-1)

                # update output log probability; finished rows get -inf so
                # exp() later yields probability 0.
                eos_logprob = log_prob[:, self.tokenizer.eos_token_id]
                eos_logprob = eos_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                eos_logprobs = torch.cat([eos_logprobs, eos_logprob[:, None]], dim=-1)

                token_logprob = torch.gather(log_prob, 1, next_tokens[:, None]).squeeze(1)
                token_logprob = token_logprob * unfinished_sequences + negative_inf * (1 - unfinished_sequences)
                output_logprob = torch.cat([output_logprob, token_logprob[:, None]], dim=-1)

                # update generated ids, model inputs for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
                model_kwargs = self.model.gpt._update_model_kwargs_for_generation(
                    outputs, model_kwargs, is_encoder_decoder=self.model.gpt.config.is_encoder_decoder
                )

                # if eos_token was found in one sentence, set sentence to finished
                unfinished_sequences = unfinished_sequences.mul((next_tokens != self.tokenizer.eos_token_id).long())
                if unfinished_sequences.max() == 0:
                    break

        # Split the concatenated ids back into prompt and generated parts.
        response_ids = input_ids[:, input_seq_len:]
        response_text = [self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                         for output in response_ids]

        prompt_ids = input_ids[:, :input_seq_len]
        prompts = [self.tokenizer.decode(query, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                   for query in prompt_ids]

        eos_probs = eos_logprobs.exp()
        return {
            'query/input_ids': prompt_ids,
            'query/text': prompts,
            'query/mask': attention_mask,
            'response/input_ids': response_ids,
            'response/text': response_text,
            'response/mask': output_mask,
            'response/log_prob': output_logprob,
            'response/eos_prob': eos_probs,
        }

    def forward_pass(self,
                     query_input_ids: torch.Tensor,
                     query_mask: torch.Tensor,
                     response_input_ids: torch.Tensor,
                     response_mask: torch.Tensor,
                     features: torch.Tensor,
                     labels: Optional[torch.Tensor] = None,
                     invalidate_eos: bool = True,
                     device=None):
        """Re-score an existing response under the current model (grads flow).

        Concatenates query and response, runs one forward pass, and realigns
        logits so position t scores response token t.

        Args:
            query_input_ids / query_mask: prompt ids and mask,
                shape (batch, query_len).
            response_input_ids / response_mask: generated ids and mask,
                shape (batch, response_len).
            features, labels: conditioning inputs forwarded to the model,
                same semantics as in `sample`.
            invalidate_eos: if True, EOS logits are masked to -inf before the
                softmax, mirroring `sample`.
            device: overrides `self.device` when given.

        Returns:
            Dict of per-token log-probs, EOS probabilities, entropies, chosen
            -token logits (all masked by `response_mask`), plus the raw logits.
        """
        if device is None:
            device = self.device
        batch_size, query_seq_len = query_input_ids.shape
        input_ids = torch.cat([query_input_ids, response_input_ids], dim=-1)
        attention_mask = torch.cat([query_mask, response_mask], dim=-1)

        # forward pass to get next token
        outputs = self.model(
            input_ids,
            features,
            attention_mask,
            labels,
            device=device
        )

        # get the first logit: the logit at each row's last real query
        # position predicts the FIRST response token.
        query_logits = outputs.logits[:, :query_seq_len, :]
        last_non_masked_idx = torch.sum(query_mask, dim=1) - 1
        first_logits = query_logits[range(batch_size), last_non_masked_idx, :]
        # get the second to last logit: response positions 0..L-2 predict
        # response tokens 1..L-1 (shift-by-one alignment).
        response_logits = outputs.logits[:, query_seq_len:-1, :]
        logits = torch.cat([first_logits[:, None], response_logits], dim=1)

        negative_inf = HALF_NEGATIVE_INF if logits.dtype == torch.half else NEGATIVE_INF
        if invalidate_eos:
            logits[:, :, self.tokenizer.eos_token_id] = negative_inf  # no endoftext

        log_prob = F.log_softmax(logits, dim=-1)
        # Log-prob of the token that was actually generated at each position.
        output_logprob = torch.gather(log_prob, 2, response_input_ids[:, :, None]).squeeze(2)
        output_entropy = logits_to_entropy(logits)
        eos_prob = F.softmax(logits, dim=-1)[:, :, self.tokenizer.eos_token_id]
        pos_logit = torch.gather(logits, 2, response_input_ids[:, :, None]).squeeze(2)

        return {
            'response/log_prob': mask_pad(output_logprob, response_mask),
            'response/eos_prob': mask_pad(eos_prob, response_mask),
            'response/entropy': mask_pad(output_entropy, response_mask),
            'response/pos_logit': mask_pad(pos_logit, response_mask),
            'response/logits': logits,
        }