| |
| import torch, math |
| import torch.nn.functional as F |
|
|
|
|
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a batch of logits with top-k and/or nucleus (top-p) filtering.

    Args:
        logits (Tensor): [..., vocab_size] logits; mutated in place.
        top_k (int): keep only the `top_k` highest-scoring tokens (0 = disabled).
        top_p (float): keep the smallest token set whose cumulative softmax
            probability exceeds `top_p` (0.0 = disabled).
        filter_value (float): value written into removed positions.

    Returns:
        Tensor: the same (mutated) logits tensor, for call chaining.
    """
    if top_k > 0:
        # Clamp so torch.topk never receives k larger than the vocabulary.
        top_k = min(top_k, logits.size(-1))
        # Drop everything strictly below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mark tokens once cumulative probability passes the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Vectorized scatter back to original token order — replaces the
        # former per-row Python loop (same result, no batch-size loop).
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
|
|
|
|
def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5):
    """Apply the repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858).

    Mutates `lprobs` in place: scores of already-generated tokens are pushed
    down — positive scores are divided, negative ones multiplied, so the
    penalty lowers the score in both cases.

    Args:
        lprobs (Tensor): [vocab_size] per-token scores, modified in place.
        prev_output_tokens (Tensor or iterable of int): tokens generated so far.
        repetition_penalty (float): factor > 1.0 penalizes repeats.
    """
    # BUG FIX: iterating a tensor yields 0-dim tensors, which hash by object
    # identity, so `set()` never deduplicated them and a token repeated k
    # times was penalized k times. Convert to plain ints so the set dedupes.
    if torch.is_tensor(prev_output_tokens):
        prev_output_tokens = prev_output_tokens.tolist()
    for previous_token in set(prev_output_tokens):
        # Dividing a negative score would *raise* it, hence the sign split.
        if lprobs[previous_token] < 0:
            lprobs[previous_token] *= repetition_penalty
        else:
            lprobs[previous_token] /= repetition_penalty
|
|
|
|
def switch(next_value, init, is_update):
    """Elementwise select: `next_value` where `is_update` is truthy, else `init`."""
    gate = is_update.type_as(next_value)
    return gate * next_value + (1 - gate) * init
|
|
|
|
def get_atten_mask(batch_size, seq_length, memory_length=0):
    """Build a banded causal attention mask.

    Returns an int16 tensor of shape
    [batch_size, 1, seq_length, seq_length + memory_length] where each query
    position may attend to its own position plus up to `memory_length`
    preceding key positions.
    """
    shape = (batch_size, 1, seq_length, seq_length + memory_length)
    mask = torch.ones(shape, dtype=torch.int16)
    # triu/tril only zero out entries, so applying them in either order
    # produces the same band: 1 - seq_length + memory_length <= j - i <= memory_length.
    mask = torch.triu(mask, diagonal=1 - seq_length + memory_length)
    mask = torch.tril(mask, diagonal=memory_length)
    return mask
|
|
|
|
def get_masks_and_position_ids(data, mem_length=None):
    """Build the causal attention mask and position ids for `data`.

    Args:
        data (Tensor): [batch_size, seq_length] token ids.
        mem_length (int, optional): number of cached memory positions the mask
            must cover. Defaults to 0.

    Returns:
        tuple: (attention_mask of shape [1, 1, seq_length, seq_length + mem_length],
                position_ids of shape [batch_size, seq_length]).
    """
    # BUG FIX: the previous default (None) raised a TypeError in the
    # arithmetic below whenever the argument was omitted.
    if mem_length is None:
        mem_length = 0
    batch_size, seq_length = data.size()
    # Banded causal mask: each position sees itself plus up to mem_length
    # earlier key positions.
    attention_mask = torch.ones((1, seq_length, seq_length + mem_length), device=data.device)
    attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + mem_length), mem_length)
    attention_mask = attention_mask.unsqueeze(1)
    # Position ids 0..seq_length-1, broadcast to every batch row.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    return attention_mask, position_ids
|
|
|
|
def sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, max_out_seq=None, mems=None,
                          end_token_id=None, repetition_penalty=1.0, temperature=1.0, top_k=0, top_p=0.0):
    """Batched autoregressive sampling with per-sample context lengths.

    Decoding starts from the shortest context in the batch. While a longer
    sample is still inside its own context, the sampled token is overridden
    with the ground-truth context token via `switch`. Samples that emit
    `end_token_id` are removed from the live batch, and all outputs are
    re-sorted into original batch order before returning.

    Args:
        model: module whose ``forward(input_ids, position_ids, attention_mask,
            hidden_states)`` returns an object exposing ``.logits`` and
            ``.hidden_states`` (transformer-XL style memories).
        context_tokens_tensor (Tensor): [bs, seq_len] context token ids (CUDA).
        context_length_tensor (Tensor): [bs] true context length per sample.
        max_out_seq (int, optional): max decoding steps. Defaults to 512.
        mems (list, optional): initial memories. Defaults to [].
        end_token_id (int, optional): stop token. Defaults to 50000.
        repetition_penalty (float, optional): NOTE(review): accepted but never
            used anywhere in this function — dead parameter.
        temperature (float, optional): logits divisor before sampling.
        top_k (int, optional): top-k filter (0 = off).
        top_p (float, optional): nucleus filter (0.0 = off).

    Returns:
        tuple: (list of per-sample token-id lists in original batch order,
                each truncated before `end_token_id` if present;
                list of per-sample scores, see scoring note in the loop).
    """
    model_dtype = next(model.parameters()).dtype
    # Start decoding at the shortest context; longer contexts are re-injected
    # token by token via `switch` below.
    org_context_length = torch.min(context_length_tensor).item()
    batch_size = context_tokens_tensor.shape[0]
    tokens = context_tokens_tensor[:, :org_context_length]
    attention_mask = get_atten_mask(batch_size, org_context_length).cuda(context_tokens_tensor.device).to(model_dtype)
    position_ids = torch.arange(org_context_length, dtype=torch.long,
                                device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)

    counter, mem_length = 0, 0
    if mems is None:
        mems = []
    if end_token_id is None:
        end_token_id = 50000
    if max_out_seq is None:
        max_out_seq = 512

    output_tokens_lists = []
    # Tracks each live row's index in the original batch so finished samples
    # can be restored to input order at the end.
    origin_order = torch.tensor(range(batch_size), device=tokens.device)
    output_order = []

    # Accumulated per-sample scores (sum of log of sampled-token probability).
    log_probs_tensor = torch.tensor([0.0] * batch_size, device=tokens.device)
    log_probs_list = []

    with torch.no_grad():
        while counter < max_out_seq:
            index = org_context_length + counter
            if counter == 0:
                # First step: full forward pass over the (truncated) context.
                output = model.forward(input_ids=tokens, position_ids=position_ids,
                                       attention_mask=attention_mask, hidden_states=mems)
                logits, mems = output.logits, output.hidden_states
            else:
                # Incremental step: feed only the latest token; position id and
                # a 1-wide attention mask are built ad hoc.
                # NOTE(review): position_ids here is shape (1, 1) even when
                # batch_size > 1 — presumably the model broadcasts; confirm.
                output = model.forward(input_ids=tokens[:, index - 1: index], position_ids=tokens.new_ones((1, 1)) * (index - 1),
                                       attention_mask=tokens.new_ones(batch_size, 1, 1, mem_length + 1).to(model_dtype), hidden_states=mems)
                logits, mems = output.logits, output.hidden_states
            logits = logits[:, -1]
            logits /= temperature
            logits = top_k_logits(logits, top_k=top_k, top_p=top_p)

            # NOTE(review): despite the name, these are softmax *probabilities*,
            # not log-probabilities; math.log is applied explicitly below.
            log_probs = F.softmax(logits, dim=-1)

            prev = torch.multinomial(log_probs, num_samples=1).view(-1)

            # While some samples are still inside their own context, force the
            # ground-truth context token instead of the sampled one.
            if index < torch.max(context_length_tensor).item():
                prev = switch(
                    prev, context_tokens_tensor[:, index], context_length_tensor <= index)

            for i in range(batch_size):
                # Accumulate log prob of freely generated (non-forced) tokens.
                if index > context_length_tensor[i] and prev[i] != end_token_id:
                    log_probs_tensor[i] += math.log(log_probs[i][prev[i]] + 1e-6)
                if prev[i] == end_token_id:
                    # NOTE(review): divides by (context_length - index), a
                    # negative value = -(generated length); this flips the sign
                    # of the accumulated log-prob — looks like length
                    # normalization, verify intended sign with callers.
                    log_probs_tensor[i] /= (context_length_tensor[i].cpu() - index)

            # Rows that just emitted the end token.
            stop_idx = prev == end_token_id
            if torch.all(stop_idx).item():
                # Everyone finished; remaining rows are flushed after the loop.
                output_order.extend(origin_order[stop_idx].tolist())
                break

            # Move finished rows out of the live batch.
            finished = tokens[stop_idx]
            output_tokens_lists.extend(finished.detach().cpu().tolist())
            log_probs_list.extend(log_probs_tensor[stop_idx].tolist())
            output_order.extend(origin_order[stop_idx].tolist())

            # Keep only unfinished rows (and their memories) in the batch.
            conti_idx = (prev != end_token_id)
            origin_order = origin_order[conti_idx]
            tokens, prev = tokens[conti_idx], prev[conti_idx]
            context_tokens_tensor = context_tokens_tensor[conti_idx]
            context_length_tensor = context_length_tensor[conti_idx]
            log_probs_tensor = log_probs_tensor[conti_idx]
            batch_size = tokens.shape[0]
            for im in range(len(mems)):
                mems[im] = mems[im][conti_idx, :, :]

            # Append the new token to every surviving row.
            tokens = torch.cat((tokens, prev.view(batch_size, 1)), dim=-1)

            counter += 1

    # Flush whatever is still live (loop exhausted or all finished at once).
    output_tokens_lists.extend(tokens.detach().cpu().tolist())
    log_probs_list.extend(log_probs_tensor.tolist())
    output_order.extend(origin_order.tolist())
    # Truncate each sequence before its first end token, if any.
    output_tokens_lists = [tokens[:tokens.index(
        end_token_id)] if end_token_id in tokens else tokens for tokens in output_tokens_lists]

    # Restore original batch order.
    output_tokens_lists = [tokens for _, tokens in sorted(zip(output_order, output_tokens_lists))]
    output_log_porbs = [prob for _, prob in sorted(zip(output_order, log_probs_list))]

    return output_tokens_lists, output_log_porbs
|
|
|
|
def sample_sequence(model, tokens, attention_mask, do_sampling=True,
                    repetition_penalty=1.0, max_out_seq=None, mems=None, end_token_id=None,
                    mem_length=0, temperature=1.0, top_k=0, top_p=0.0):
    """Autoregressively sample a single sequence (batch size 1).

    Args:
        model: callable as ``model(input_ids=..., position_ids=...,
            attention_mask=..., mems=...)`` returning ``(logits, *new_mems)``.
        tokens (Tensor): [1, seq_len] context token ids.
        attention_mask (Tensor): [1, 1, seq_len, seq_len] mask for the first pass.
        do_sampling (bool, optional): apply top-k/top-p filtering. Defaults to True.
        repetition_penalty (float, optional): CTRL-style penalty; 1.0 disables.
            NOTE(review): applied to softmax *probabilities* (not logits) and
            the distribution is not re-normalized afterwards — verify intent.
        max_out_seq (int, optional): max generated tokens. Defaults to 512.
        mems (list, optional): transformer-XL style memories. Defaults to [].
        end_token_id (int, optional): stop token. Defaults to 50000.
        mem_length (int, optional): sizes the incremental attention mask.
        temperature (float, optional): logits divisor. Defaults to 1.0.
        top_k (int, optional): top-k filter (0 = off).
        top_p (float, optional): nucleus filter (0.0 = off).

    Returns:
        tuple: (flat list of token ids — context plus generated, truncated
                before `end_token_id` if it occurs; final mems)
    """
    counter = 0
    if mems is None:
        mems = []
    if end_token_id is None:
        end_token_id = 50000
    if max_out_seq is None:
        max_out_seq = 512
    org_context_length = tokens.size(1)
    with torch.no_grad():
        while counter < max_out_seq:
            if counter == 0:
                # First step: full forward pass over the whole context.
                logits, *mems = model(input_ids=tokens, position_ids=None,
                                      attention_mask=attention_mask, mems=mems)
            else:
                # Incremental decoding: feed only the last generated token.
                index = org_context_length + counter
                logits, *mems = model(input_ids=tokens[:, index - 1: index], position_ids=None,
                                      attention_mask=tokens.new_ones(1, 1, 1, mem_length + 1), mems=mems)
            logits = logits[:, -1]
            logits /= temperature
            if do_sampling:
                logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
            # NOTE: softmax probabilities, despite the variable name.
            log_probs = F.softmax(logits, dim=-1)

            if repetition_penalty != 1.0:
                enforce_repetition_penalty(
                    log_probs[0, :], tokens[0, :], repetition_penalty)
            prev = torch.multinomial(log_probs, num_samples=1)[0]
            is_end = (prev == end_token_id)
            if is_end:
                break
            tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
            counter += 1

    # BUG FIX: `tokens.detach().cpu().tolist()` on a [1, seq] tensor yields a
    # *nested* list, so `end_token_id in output_tokens_list` compared an int
    # against a list and the truncation below could never fire (and, had it
    # fired, `[0]` would have returned a single int). Flatten first; the
    # returned value is the same flat list of ids callers always received.
    output_tokens_list = tokens.view(-1).detach().cpu().tolist()
    if end_token_id in output_tokens_list:
        output_tokens_list = output_tokens_list[:output_tokens_list.index(
            end_token_id)]

    return output_tokens_list, mems
|
|