def custom_generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    streamer=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    **kwargs,
):
    """Sample up to ``max_new_tokens`` tokens per sequence, one token per step.

    At each step the model (``self``) is called on the rows that have not
    finished yet, the logits at each row's last non-pad position are turned
    into a distribution with temperature scaling, and one token is drawn with
    ``torch.multinomial``. The drawn token is written into ``input_ids``
    (which is widened with pad columns when a row is full) and into the
    output buffer.

    Args:
        self: Callable model. It is invoked as
            ``self(input_ids, attention_mask=..., **kwargs)`` and the result
            is indexed with ``['logits']``. It must also expose
            ``self.tokenizer`` and ``self.device``.
        input_ids: 2-D tensor of token ids. ``None`` or an empty tensor is
            replaced by a single-row prompt holding the tokenizer's BOS id
            (and ``attention_mask`` is reset to all ones).
        attention_mask: Optional mask matching ``input_ids``; extended
            alongside ``input_ids`` as tokens are appended.
        max_new_tokens: Maximum number of sampling steps. ``None`` falls back
            to 128.
        temperature: Divisor applied to the logits before softmax. ``None``
            falls back to 1.0 (no scaling).
        streamer: Optional object with a ``put`` method; receives the sampled
            token after every step that does not end generation.
        **kwargs: Forwarded verbatim to the model call.

        The remaining keyword parameters (``min_length``, ``do_sample``,
        ``num_beams``, ``top_k``, ``top_p``, ``bos_token_id``,
        ``pad_token_id``, ``eos_token_id`` and so on) are accepted for
        signature compatibility but are not read by this function; the
        special-token ids are always taken from ``self.tokenizer``.

    Returns:
        A ``torch.long`` tensor of shape ``(batch_size, max_new_tokens)`` with
        the sampled ids; slots that were never filled hold the tokenizer's
        pad id.
    """
    # The signature defaults are None, but both values are used numerically
    # below (tensor shape / range bound, and a divisor), so substitute usable
    # defaults instead of failing with a TypeError.
    if max_new_tokens is None:
        max_new_tokens = 128
    if temperature is None:
        temperature = 1.0

    if input_ids is None or input_ids.nelement() == 0:
        # No prompt supplied: start from a lone BOS token.
        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
        attention_mask = torch.ones_like(input_ids).to(self.device)

    tokenizer = self.tokenizer
    pad_id = tokenizer.pad_token_id
    # Any of these ids ends generation for a row. Computed once because none
    # of them changes during the loop.
    stop_ids = {
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        pad_id,
        tokenizer.convert_tokens_to_ids("</s>"),
    }

    device = input_ids.device
    with torch.no_grad():
        batch_size = input_ids.shape[0]
        finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
        generated_token_ids = torch.full(
            (batch_size, max_new_tokens), pad_id, dtype=torch.long, device=device
        )

        for cur_token_idx in range(max_new_tokens):
            active = ~finished_generating
            # Only rows that are still generating are sent through the model.
            logits = self(
                input_ids[active],
                attention_mask=attention_mask[active] if attention_mask is not None else None,
                **kwargs
            )['logits']

            # Ids at or beyond tokenizer.vocab_size can never be sampled.
            # NOTE(review): presumably these are extra special tokens added
            # to the embedding matrix -- confirm against the model definition.
            logits[:, :, tokenizer.vocab_size:] = -float("inf")

            # list_idx indexes the compacted model output, answer_idx the
            # original batch row.
            for list_idx, answer_idx in enumerate(active.nonzero(as_tuple=True)[0]):
                base_answer_ids = input_ids[answer_idx]
                # Position of the last non-pad token in this row.
                last_token_idx = (base_answer_ids != pad_id).nonzero(as_tuple=True)[0].max()

                probs = torch.nn.functional.softmax(
                    logits[list_idx][last_token_idx] / temperature, dim=-1
                )
                new_ids_sampled = torch.multinomial(probs, 1)

                if last_token_idx + 1 >= len(base_answer_ids):
                    # Row is full: widen every row by one pad column so the
                    # batch stays rectangular.
                    new_padding = torch.full(
                        (batch_size, 1), pad_id, dtype=torch.long, device=device
                    )
                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
                    if attention_mask is not None:
                        attention_mask = torch.cat(
                            [attention_mask, torch.zeros_like(new_padding)], dim=-1
                        )

                if attention_mask is not None:
                    attention_mask[answer_idx, last_token_idx + 1] = 1
                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled

                if new_ids_sampled.item() in stop_ids:
                    finished_generating[answer_idx] = True

            # Stop before streaming: the terminating token of the final row
            # is not pushed to the streamer.
            if finished_generating.all():
                break

            if streamer is not None:
                # Streams the token sampled for the last active row of this step.
                streamer.put(new_ids_sampled)

    return generated_token_ids
| |
|
| |
|
def generate(
    self,
    input_ids,
    attention_mask=None,
    max_new_tokens=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=1.1,
    streamer=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    decoder_start_token_id=None,
    use_cache=None,
    num_beam_groups=None,
    diversity_penalty=None,
    prefix_allowed_tokens_fn=None,
    output_attentions=None,
    output_hidden_states=None,
    output_scores=None,
    return_dict_in_generate=None,
    forced_bos_token_id=None,
    forced_eos_token_id=None,
    remove_invalid_values=None,
    synced_gpus=None,
    n_ahead=4,
    n_ahead_talk=4,
    merged_talk_heads=True,
    merged_lm_and_talk_heads=False,
    merged_lm_and_think_heads=True,
    use_concat_talk_head=True,
    use_shallow_think=True,
    use_shallow_talk=False,
    use_complex_think_head=False,
    use_complex_talk_head=True,
    use_weighted_talk_head=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs,
):
    """Configure the model for inference and delegate to ``custom_generate``.

    Every call overwrites a set of configuration attributes on ``self``
    (thought/talk head switches, ``n_ahead``, ``max_thoughts``, and several
    hard-coded evaluation flags) before sampling, so the settings of the most
    recent call persist on the model object.

    Args:
        self: The model; must expose ``tokenizer`` and ``device``.
        input_ids: Prompt as a tensor of token ids, or a ``str`` which is
            encoded with ``self.tokenizer.encode(..., return_tensors='pt')``.
            ``None`` is passed through untouched so that ``custom_generate``
            can substitute its default prompt.
        attention_mask: Optional mask; moved to ``self.device`` when given.
        max_new_tokens: Maximum number of tokens to sample; ``None`` means 128.
        temperature: Sampling temperature forwarded to ``custom_generate``.
        streamer: Optional streamer forwarded to ``custom_generate``.
        n_ahead, n_ahead_talk: Stored on ``self``; their sum plus one becomes
            ``self.max_thoughts``.
        merged_talk_heads, merged_lm_and_talk_heads,
        merged_lm_and_think_heads, use_concat_talk_head, use_shallow_think,
        use_shallow_talk, use_complex_think_head, use_complex_talk_head,
        use_weighted_talk_head: Copied verbatim onto ``self`` under the same
            names.
        trust_remote_code, torch_dtype: Accepted but not used by this
            function.
        **model_kwargs: Forwarded to ``custom_generate``.

        All other keyword parameters are forwarded unchanged to
        ``custom_generate``.

    Returns:
        Whatever ``custom_generate`` returns (the generated token ids).
    """
    if max_new_tokens is None:
        max_new_tokens = 128

    # Thought / talk head configuration taken from the call arguments.
    self.max_thoughts = n_ahead + n_ahead_talk + 1
    self.merged_talk_heads = merged_talk_heads
    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
    self.merged_lm_and_think_heads = merged_lm_and_think_heads
    self.use_concat_talk_head = use_concat_talk_head
    self.use_shallow_think = use_shallow_think
    self.use_shallow_talk = use_shallow_talk
    self.use_complex_think_head = use_complex_think_head
    self.use_complex_talk_head = use_complex_talk_head
    self.use_weighted_talk_head = use_weighted_talk_head

    # Fixed settings applied on every generation call.
    self.use_end_thought_token = True
    self.use_start_thought_token = True
    self.n_ahead = n_ahead
    self.n_passes = 1
    self.eval_mode = True
    self.first_run = False
    self.rm_initialized = True
    self.original_mode = False

    # Allow a raw text prompt.
    if isinstance(input_ids, str):
        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')

    # Guard the device move: custom_generate accepts input_ids=None and
    # builds a default prompt, so None must not fail here first.
    if input_ids is not None:
        input_ids = input_ids.to(self.device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(self.device)

    generated_token_ids = custom_generate(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        min_length=min_length,
        do_sample=do_sample,
        early_stopping=early_stopping,
        num_beams=num_beams,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        bad_words_ids=bad_words_ids,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        num_return_sequences=num_return_sequences,
        decoder_start_token_id=decoder_start_token_id,
        use_cache=use_cache,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_scores=output_scores,
        return_dict_in_generate=return_dict_in_generate,
        forced_bos_token_id=forced_bos_token_id,
        forced_eos_token_id=forced_eos_token_id,
        remove_invalid_values=remove_invalid_values,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )

    return generated_token_ids