# NOTE(review): "Spaces: Running" page-header residue from the source capture
# (Hugging Face Spaces UI text); not part of the module.
| import torch.nn as nn | |
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from lm_steer.utils import set_seed | |
| from .model_utils import find_max_subspans | |
# Characters treated as word boundaries when collapsing per-token evidence
# scores into word-level spans (used by LMSteerBase.steer_analysis).
# NOTE(review): '/' and '#' are deliberately excluded (commented out); the
# Latin-1 symbol range presumably covers punctuation common in web text —
# TODO confirm intent with original authors.
punctuations = [
    '!', '"', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.',
    # '/', '#',
    ':', ';', '<', '=', '>', '?', '@',
    '[', '\\', ']', '^', '_', '`',
    '{', '|', '}', '~',
    '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', 'µ', '¶', '·',
    '¸', '¹', 'º', '»', '¼', '½', '¾',
    '\n', ' ',
]
| class LMSteerBase(nn.Module): | |
| def evidence_words(self, prompt, comparing_steer_values, | |
| truncation_length=1024, max_segments=4, max_length=10): | |
| if isinstance(comparing_steer_values, list): | |
| comparing_steer_values = \ | |
| torch.Tensor(comparing_steer_values).to(self.device) | |
| if (comparing_steer_values[0] - comparing_steer_values[1]).abs().sum()\ | |
| <= 0.2: | |
| return [(prompt, None)] | |
| tokenized = self.tokenizer( | |
| prompt, return_tensors="pt", | |
| max_length=truncation_length, truncation=True) | |
| input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device) | |
| input_ids = input_ids.expand(2, -1) | |
| attention_mask = torch.LongTensor(tokenized["attention_mask"]).to( | |
| self.device) | |
| attention_mask = attention_mask.expand(2, -1) | |
| self.steer.set_value(comparing_steer_values) | |
| with torch.no_grad(): | |
| output = self.model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| labels=input_ids) | |
| length = input_ids.shape[1] | |
| loss_token = F.cross_entropy( | |
| output.logits[:, :-1].reshape((2)*(length-1), -1), | |
| input_ids[:, 1:].reshape(-1), | |
| reduction="none" | |
| ) | |
| loss_token = loss_token.reshape(2, length - 1) | |
| token_evidence = (- loss_token[0] + loss_token[1]) | |
| tokens = input_ids[0] | |
| evidence_segments = find_max_subspans( | |
| token_evidence.cpu().numpy().tolist(), max_segments, max_length)[0] | |
| evidence_segments = [ | |
| (_seg[0]+1, _seg[1]+1) for _seg in evidence_segments] | |
| start = 0 | |
| output = [] | |
| if len(evidence_segments) > 0: | |
| for _segment in evidence_segments: | |
| if _segment[0] > start: | |
| output.append(( | |
| self.tokenizer.decode(tokens[start: _segment[0]]), | |
| None | |
| )) | |
| output.append(( | |
| self.tokenizer.decode(tokens[_segment[0]: _segment[1]]), | |
| "evidence" | |
| )) | |
| start = _segment[1] | |
| length = tokens.shape[-1] | |
| if _segment[1] < length: | |
| output.append(( | |
| self.tokenizer.decode(tokens[_segment[1]: length]), | |
| None | |
| )) | |
| else: | |
| output = [(prompt, None)] | |
| return output, token_evidence.tolist() | |
| def steer_analysis(self, prompt, steer_dim, min_value=-3, max_value=3, | |
| bins=7): | |
| tokenized = self.tokenizer(prompt) | |
| input_ids = torch.LongTensor(tokenized["input_ids"]).to(self.device) | |
| input_ids = input_ids.expand(bins + 1, -1) | |
| attention_mask = torch.LongTensor(tokenized["attention_mask"]).to( | |
| self.device) | |
| attention_mask = attention_mask.expand(bins + 1, -1) | |
| steer_values = torch.zeros(bins+1, self.num_steers).to(self.device) | |
| for bin_i in range(bins): | |
| steer_values[bin_i, steer_dim] = ( | |
| min_value + (max_value - min_value) / (bins - 1) * bin_i | |
| ) | |
| self.steer.set_value(steer_values) | |
| with torch.no_grad(): | |
| output = self.model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| labels=input_ids) | |
| length = input_ids.shape[1] | |
| loss_token = F.cross_entropy( | |
| output.logits[:, :-1].reshape((bins+1)*(length-1), -1), | |
| input_ids[:, 1:].reshape(-1), | |
| reduction="none" | |
| ) | |
| loss_token = loss_token.reshape(bins + 1, length - 1) | |
| loss = loss_token.mean(-1)[:-1] | |
| dist = ((- loss + loss.mean()) * 10).softmax(0) | |
| dist_list = list(zip( | |
| [ | |
| min_value + (max_value - min_value) / (bins - 1) * bin_i | |
| for bin_i in range(bins) | |
| ], | |
| dist.tolist(), | |
| )) | |
| best_guess = loss.argmin(0) | |
| best_guess_value = min_value + \ | |
| (max_value - min_value) / (bins - 1) * best_guess.item() | |
| token_evidence = (- loss_token[best_guess] + loss_token[-1]) * 10 | |
| token_evidence = [0] + token_evidence.tolist() | |
| # tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0]) | |
| word_evidence_list = [] | |
| start = 0 | |
| n_tokens = len(input_ids[0]) | |
| for token_i in range(1, n_tokens+1): | |
| span = self.tokenizer.decode(input_ids[0][start: token_i]) | |
| for _punc in punctuations: | |
| if token_i == n_tokens or _punc in span: | |
| new_span = self.tokenizer.decode( | |
| input_ids[0][start: token_i-1]).strip() | |
| if len(new_span) <= 1: | |
| break | |
| word_evidence_list.append(( | |
| new_span, | |
| np.array(token_evidence[start: token_i-1]).mean() | |
| )) | |
| start = token_i - 1 | |
| break | |
| # token_evidence_list = list(zip(tokens, token_evidence)) | |
| return best_guess_value, dist_list, word_evidence_list | |
| def generate(self, prompt, steer_values, min_length=20, max_length=100, | |
| seed=None, num_beams=1, num_beam_groups=1, do_sample=True, | |
| temperature=1, top_p=1): | |
| ''' | |
| prompt: a string | |
| steer_values | |
| min_length: minimum generation length | |
| max_length: maximum generation length | |
| seed: seed for generation. None if not specified. | |
| ''' | |
| if seed is not None: | |
| set_seed(seed) | |
| steer_values = torch.Tensor(steer_values).to( | |
| self.device) | |
| self.steer.set_value(steer_values[None]) | |
| with torch.no_grad(): | |
| inputs = self.tokenizer( | |
| prompt, return_tensors="pt").to(self.device) | |
| text = self.model.generate( | |
| **inputs, | |
| num_beams=num_beams, num_beam_groups=num_beam_groups, | |
| do_sample=do_sample, temperature=temperature, top_p=top_p, | |
| min_length=min_length, max_length=max_length, | |
| pad_token_id=self.tokenizer.pad_token_id, | |
| ) | |
| text = self.tokenizer.decode(text[0], skip_special_tokens=True) | |
| return text | |
| def generate_low_resource( | |
| self, prompt, steer_values, min_length=20, max_length=100, | |
| seed=None, num_beams=1, num_beam_groups=1, do_sample=True, | |
| temperature=1, top_p=1 | |
| ): | |
| ''' | |
| prompt: a string | |
| steer_values | |
| min_length: minimum generation length | |
| max_length: maximum generation length | |
| seed: seed for generation. None if not specified. | |
| ''' | |
| if seed is not None: | |
| set_seed(seed) | |
| steer_values = torch.Tensor(steer_values).to( | |
| self.device) | |
| fp16 = torch.float16 | |
| steer_values = steer_values.to(fp16) | |
| self.steer.projector1.data = self.steer.projector1.to(fp16) | |
| self.steer.projector2.data = self.steer.projector2.to(fp16) | |
| self.steer.set_value(steer_values[None]) | |
| with torch.no_grad(): | |
| input_ids = self.tokenizer( | |
| prompt, return_tensors="pt").input_ids.to(self.device) | |
| gen_tokens = self.model.generate( | |
| input_ids, | |
| num_beams=num_beams, num_beam_groups=num_beam_groups, | |
| do_sample=do_sample, temperature=temperature, top_p=top_p, | |
| min_length=min_length, max_length=max_length, | |
| pad_token_id=self.tokenizer.pad_token_id) | |
| text = self.tokenizer.batch_decode(gen_tokens)[0] | |
| # recovering | |
| fp32 = torch.float32 | |
| self.steer.projector1.data = self.steer.projector1.to(fp32) | |
| self.steer.projector2.data = self.steer.projector2.to(fp32) | |
| return text | |
| def state_dict(self): | |
| return self.steer.state_dict() | |
| def load_state_dict(self, state_dict): | |
| self.steer.load_state_dict(state_dict) | |
| def parameters(self): | |
| return self.steer.parameters() | |
| def to_device(self, device): | |
| self.model.to(device) | |
| self.device = device | |
| def regularization_term(self): | |
| return self.steer.regularization_term() | |
| def forward(self, input_ids, attention_mask, steer_values): | |
| self.steer.set_value(steer_values) | |
| output = self.model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| labels=input_ids) | |
| return output | |