| import jittor as jt |
|
|
def generate(moss, input_str, tokenizer, method, **kwargs):
    """
    Dispatch text generation to the decoding routine named by ``method``.

    :param moss: The model; handed unchanged to the decoding routine.
    :param input_str: The input text.
    :param tokenizer: Tokenizer.
    :param method: Decoding strategy name, either ``'greedy'`` or ``'sample'``.
    :param kwargs: Extra keyword arguments forwarded to the decoding routine.

        - max_gen_len: int. Length limit. Used in all methods.
        - temperature: float. Used in ``sample``.
        - top_p: float. Used in ``sample``.
        - top_k: int. Used in ``sample``.
    :return: Whatever the selected decoding routine returns.
    :raises NotImplementedError: If ``method`` names no supported strategy.
    """
    if method == "sample":
        return sample(moss, input_str, tokenizer, **kwargs)
    if method == "greedy":
        return greedy_search(moss, input_str, tokenizer, **kwargs)
    # Neither known strategy matched.
    raise NotImplementedError(f"Unsupported generation method {method}")
|
|
def greedy_search(model, input_str, tokenizer, max_gen_len,
                  eos_token_id=None, pad_token_id=None):
    """
    Generate token ids by repeatedly appending the highest-scoring next token.

    :param model: Model called as ``model(input_ids, past_key_values=...,
        attention_mask=...)``. Its output is indexed with ``'logits'`` and
        ``'past_key_values'``.
    :param input_str: The input text.
    :param tokenizer: Tokenizer. Called with ``return_tensors='np'``; its
        ``eos_token_id`` is used when ``eos_token_id`` is not given.
    :param max_gen_len: The loop stops once the token sequence (prompt tokens
        plus generated tokens) is at least this long. At least one decoding
        step always runs.
    :param eos_token_id: End-of-sequence token id. Defaults to
        ``tokenizer.eos_token_id``.
    :param pad_token_id: Id appended for sequences that already finished.
        Defaults to ``eos_token_id``.
    :return: A flat list: all sequence ids flattened, with the first
        prompt-length entries dropped.
    """
    model.eval()
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if pad_token_id is None and eos_token_id is not None:
        pad_token_id = eos_token_id
    eos_token_id_tensor = jt.Var(eos_token_id)

    tokenized = tokenizer(input_str, return_tensors='np')
    sentence_ids = jt.Var(tokenized['input_ids'])
    attention_mask = jt.Var(tokenized['attention_mask'])
    # One flag per row: 1 while the row is still generating, 0 once it hit EOS.
    unfinished_sequences = sentence_ids.new(sentence_ids.shape[0]).fill_(1)
    past_key_values = None
    while True:
        # With cached key/values only the newest token is fed to the model;
        # on the first step the whole prompt is fed.
        if past_key_values:
            input_ids = sentence_ids[:, -1].unsqueeze(-1)
        else:
            input_ids = sentence_ids

        outputs = model(input_ids, past_key_values=past_key_values,
                        attention_mask=attention_mask)

        # Scores for the last position only.
        next_token_logits = outputs['logits'][:, -1, :].float()
        # NOTE(review): [0] takes the first element of jt.argmax's result,
        # presumably the indices -- confirm against the jittor docs.
        next_tokens = jt.argmax(next_token_logits, dim=-1)[0]

        # Finished rows keep receiving pad_token_id instead of a new token.
        next_tokens = next_tokens * unfinished_sequences + \
            pad_token_id * (1 - unfinished_sequences)
        sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1)

        past_key_values = outputs['past_key_values']
        # Extend the mask by one column of ones for the token just appended.
        attention_mask = jt.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        # Compare the new tokens against every EOS id; the product over the
        # EOS axis is 0 for rows whose token matched an EOS id, which clears
        # their "unfinished" flag.
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.repeat(eos_token_id_tensor.shape[0], 1)
            .not_equal(eos_token_id_tensor.unsqueeze(1))
            .prod(dim=0)
        )

        # Synchronize jittor before evaluating the stopping test.
        jt.sync_all()

        if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len:
            break

    return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:]
|
|
def sample(model, input_str, tokenizer, max_gen_len, temperature, top_p, top_k,
           eos_token_id=None, pad_token_id=None):
    """
    Generate token ids by sampling each next token from filtered scores.

    At every step the last-position logits are divided by ``temperature``,
    filtered by ``sample_top_k`` and then ``sample_top_p``, turned into
    probabilities with softmax, and one token per row is drawn with
    ``jt.multinomial``.

    :param model: Model called as ``model(input_ids, past_key_values=...,
        attention_mask=...)``. Its output is indexed with ``'logits'`` and
        ``'past_key_values'``.
    :param input_str: The input text.
    :param tokenizer: Tokenizer. Called with ``return_tensors='np'``; its
        ``eos_token_id`` is used when ``eos_token_id`` is not given.
    :param max_gen_len: The loop stops once the token sequence (prompt tokens
        plus generated tokens) is at least this long. At least one decoding
        step always runs.
    :param temperature: Divisor applied to the logits before filtering.
    :param top_p: Passed to ``sample_top_p``.
    :param top_k: Passed to ``sample_top_k``.
    :param eos_token_id: End-of-sequence token id. Defaults to
        ``tokenizer.eos_token_id``.
    :param pad_token_id: Id appended for sequences that already finished.
        Defaults to ``eos_token_id``.
    :return: A flat list: all sequence ids flattened, with the first
        prompt-length entries dropped.
    """
    model.eval()
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if pad_token_id is None and eos_token_id is not None:
        pad_token_id = eos_token_id
    eos_token_id_tensor = jt.Var(eos_token_id)

    tokenized = tokenizer(input_str, return_tensors='np')
    sentence_ids = jt.Var(tokenized['input_ids'])
    attention_mask = jt.Var(tokenized['attention_mask'])
    # One flag per row: 1 while the row is still generating, 0 once it hit EOS.
    unfinished_sequences = sentence_ids.new(sentence_ids.shape[0]).fill_(1)
    past_key_values = None

    while True:
        # With cached key/values only the newest token is fed to the model;
        # on the first step the whole prompt is fed.
        if past_key_values:
            input_ids = sentence_ids[:, -1].unsqueeze(-1)
        else:
            input_ids = sentence_ids
        outputs = model(input_ids, past_key_values=past_key_values,
                        attention_mask=attention_mask)

        # Scores for the last position only.
        next_token_logits = outputs['logits'][:, -1, :].float()

        # Temperature scaling, then top-k, then top-p filtering.
        scores = next_token_logits / temperature
        scores = sample_top_k(scores, top_k)
        scores = sample_top_p(scores, top_p)

        probs = jt.nn.softmax(scores, dim=-1)
        next_tokens = jt.multinomial(probs, num_samples=1).squeeze(1)

        # Finished rows keep receiving pad_token_id instead of a new token.
        next_tokens = next_tokens * unfinished_sequences + \
            pad_token_id * (1 - unfinished_sequences)

        sentence_ids = jt.cat([sentence_ids, next_tokens[:, None]], dim=-1)
        past_key_values = outputs['past_key_values']
        # Extend the mask by one column of ones for the token just appended.
        attention_mask = jt.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

        # Compare the new tokens against every EOS id; the product over the
        # EOS axis is 0 for rows whose token matched an EOS id, which clears
        # their "unfinished" flag.
        unfinished_sequences = unfinished_sequences.mul(
            next_tokens.repeat(eos_token_id_tensor.shape[0], 1)
            .not_equal(eos_token_id_tensor.unsqueeze(1))
            .prod(dim=0)
        )

        # Synchronize jittor before evaluating the stopping test.
        jt.sync_all()

        if unfinished_sequences.max() == 0 or sentence_ids.shape[-1] >= max_gen_len:
            break

    return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:]
|
|
def sample_top_k(scores, top_k):
    """
    Mask every score that is smaller than the ``top_k``-th largest score.

    :param scores: Scores to filter; the last axis is the one ranked.
    :param top_k: Number of highest scores to keep. It is clamped to the size
        of the last axis. ``None`` or a non-positive value disables the
        filter and ``scores`` is returned as-is.
    :return: ``scores`` with the filtered-out entries set to ``-inf``.
    """
    # A missing or non-positive k cannot select a threshold; treat it as
    # "no top-k filtering" instead of failing inside jt.topk.
    if top_k is None or top_k <= 0:
        return scores

    top_k = min(top_k, scores.size(-1))
    # The last of the k largest values is the cut-off; the trailing None keeps
    # an axis so the comparison broadcasts across the last dimension.
    kth_largest = jt.topk(scores, top_k)[0][..., -1, None]
    indices_to_remove = scores < kth_largest
    return scores.masked_fill(indices_to_remove, -float("Inf"))
|
|
def sample_top_p(scores, top_p):
    """
    Mask the low-probability tail whose cumulative probability is at most
    ``1 - top_p`` (nucleus filtering).

    :param scores: Scores to filter. The mask is scattered back along
        dimension 1.
    :param top_p: Probability mass to keep. ``None`` or a value ``>= 1``
        disables the filter and ``scores`` is returned as-is.
    :return: ``scores`` with the filtered-out entries set to ``-inf``. The
        highest-scoring entry of each row is never filtered out.
    """
    # Keeping the full probability mass removes nothing that could be
    # sampled, so skip the sort entirely.
    if top_p is None or top_p >= 1.0:
        return scores

    # Ascending sort: the cumulative sum grows from the least likely token.
    sorted_logits, sorted_indices = jt.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Always keep the last sorted position (the highest score). Otherwise a
    # very small top_p, or rounding in the cumulative sum, could mask every
    # entry and the later softmax would produce NaN.
    sorted_indices_to_remove = jt.cat(
        [sorted_indices_to_remove[..., :-1],
         jt.zeros_like(sorted_indices_to_remove[..., -1:])],
        dim=-1)

    # Map the mask from sorted order back to the original positions.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, -float("Inf"))
|
|