import os

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from .utils import *
from .kv_cache import initialize_past_key_values
from .choices import mc_sim_7b_63
from .cnets import Model
from .configs import EConfig
class ResBlock(nn.Module):
    """
    A residual block.

    Applies a linear transformation followed by a SiLU activation, then adds
    the result to the original input, forming a residual connection.

    Args:
        hidden_size (int): The size of the hidden layers in the block.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        # Zero-initialize the weight so the residual branch starts near zero
        # (only SiLU(bias) remains) and the block begins close to an identity
        torch.nn.init.zeros_(self.linear.weight)
        # Use SiLU activation to stay consistent with the Llama model
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Forward pass of the ResBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output after the residual connection and activation.
        """
        return x + self.act(self.linear(x))
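
# Usage sketch: shapes are preserved, and because the weight is
# zero-initialized the block starts close to an identity mapping while
# remaining fully trainable.
#   blk = ResBlock(hidden_size=4096)
#   x = torch.randn(1, 8, 4096)
#   y = blk(x)  # same shape as x: (1, 8, 4096)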
class EaModel(nn.Module):

    def __init__(
            self,
            base_model,
            base_model_name_or_path,
            ea_model_path,
    ):
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config
        self.hidden_size = base_model.lm_head.weight.shape[-1]
        self.vocab_size = base_model.lm_head.weight.shape[0]
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
        config = EConfig.from_pretrained(ea_model_path)
        self.ea_layer = Model(config)

        device = base_model.model.layers[-1].self_attn.q_proj.weight.device
        self.ea_layer.to(torch.float16).to(device)
        self.ea_layer.init_tree()
    def get_tokenizer(self):
        """Get the tokenizer of the base model.

        Returns:
            Tokenizer: The tokenizer of the base model.
        """
        return self.tokenizer
    @classmethod
    def from_pretrained(
            cls,
            base_model_path=None,
            ea_model_path=None,
            **kwargs,
    ):
        base_model = KVLlamaForCausalLM.from_pretrained(
            base_model_path, **kwargs
        )
        model = cls(
            base_model,
            base_model_path,
            ea_model_path
        )
        # Load the EAGLE head weights; fall back to downloading from the
        # Hugging Face Hub when ea_model_path is not a local directory.
        load_model_path = os.path.join(ea_model_path, "pytorch_model.bin")
        if not os.path.exists(load_model_path):
            load_model_path = hf_hub_download(ea_model_path, "pytorch_model.bin")
        ea_layer_state_dict = torch.load(load_model_path, map_location=base_model.device)
        model.ea_layer.load_state_dict(ea_layer_state_dict, strict=False)
        return model
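
    # Usage sketch (the checkpoint paths below are illustrative placeholders,
    # not pinned to any specific release):
    #   model = EaModel.from_pretrained(
    #       base_model_path="meta-llama/Llama-2-7b-chat-hf",
    #       ea_model_path="path/or/hub-id/of/eagle-head",
    #       torch_dtype=torch.float16,
    #   )
    #   model.eval()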
    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            labels=None,
            past_key_values=None,
            output_orig=False,
            position_ids=None,
            init=True,
            logits_processor=None
    ):
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            if output_orig:
                orig = self.base_model.lm_head(outputs[0])
            # Clone the output hidden states
            hidden_states = outputs[0].clone()
        if init:
            # Sample (or greedily pick) the next token from the base-model
            # logits, then let the EAGLE head draft candidate continuations
            if logits_processor is not None:
                logits = orig[:, -1]
                logits = logits_processor(None, logits)
                probabilities = torch.nn.functional.softmax(logits, dim=1)
                token = torch.multinomial(probabilities, 1)
            else:
                token = torch.argmax(orig[:, -1])
                token = token[None, None]
            input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
            ea_logits = self.ea_layer.topK_genrate(hidden_states, input_ids, self.base_model.lm_head, logits_processor)
            if output_orig:
                return ea_logits, outputs, orig, hidden_states, token
            return ea_logits, hidden_states, token
        else:
            if output_orig:
                return outputs, orig, hidden_states
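
    # Call sketch (internal API): with init=True the caller must also pass
    # output_orig=True, since the base-model logits `orig` seed the draft.
    #   ea_logits, outputs, orig, hidden_states, token = model(
    #       input_ids, past_key_values=past_key_values, output_orig=True
    #   )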
    def eagenerate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_new_tokens=512,
            max_length=2048,
            tree_choices=mc_sim_7b_63,
    ):
        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        if hasattr(self, "tree_choices") and self.tree_choices == tree_choices:
            tree_buffers = self.tree_buffers
        else:
            tree_buffers = generate_tree_buffers(
                tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device
            )
            self.tree_buffers = tree_buffers
            self.tree_choices = tree_choices

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        tree_logits, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
        )
        new_token = 0

        for idx in range(max_length):
            candidates, cart_candidates_prob, tree_candidates = generate_candidates(
                tree_logits,
                tree_buffers["tree_indices"],
                tree_buffers["retrieve_indices"],
                sample_token,
                logits_processor
            )
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                tree_candidates,
                past_key_values,
                tree_buffers["tree_position_ids"],
                input_ids,
                tree_buffers["retrieve_indices"],
            )
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor, cart_candidates_prob
            )
            input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                tree_buffers["retrieve_indices"],
                logits_processor,
                logits,
                tree_logits,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state,
                hidden_state_new,
                sample_p
            )
            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                return input_ids
            if new_token > max_new_tokens:
                return input_ids
            if input_ids.shape[1] > max_length:
                return input_ids
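
    # Usage sketch (assumes a model built via EaModel.from_pretrained and
    # prompt ids on the model's device):
    #   tok = model.get_tokenizer()
    #   prompt_ids = tok("Hello, world!", return_tensors="pt").input_ids.to(model.base_model.device)
    #   output_ids = model.eagenerate(prompt_ids, temperature=0.7, max_new_tokens=256)
    #   print(tok.decode(output_ids[0], skip_special_tokens=True))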
    def ea_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_steps=512,
            tree_choices=mc_sim_7b_63,
    ):
        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        if hasattr(self, "tree_choices") and self.tree_choices == tree_choices:
            tree_buffers = self.tree_buffers
        else:
            tree_buffers = generate_tree_buffers(
                tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device
            )
            self.tree_buffers = tree_buffers
            self.tree_choices = tree_choices

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        tree_logits, logits, hidden_state, sample_token = initialize_tree(
            input_ids, self, tree_buffers["tree_attn_mask"], past_key_values, logits_processor
        )
        new_token = 0

        for idx in range(max_steps):
            candidates, cart_candidates_prob, tree_candidates = generate_candidates(
                tree_logits,
                tree_buffers["tree_indices"],
                tree_buffers["retrieve_indices"],
                sample_token,
                logits_processor
            )
            logits, hidden_state_new, outputs = tree_decoding(
                self,
                tree_candidates,
                past_key_values,
                tree_buffers["tree_position_ids"],
                input_ids,
                tree_buffers["retrieve_indices"],
            )
            best_candidate, accept_length, sample_p = evaluate_posterior(
                logits, candidates, logits_processor, cart_candidates_prob
            )
            input_ids, tree_logits, new_token, hidden_state, sample_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                tree_buffers["retrieve_indices"],
                logits_processor,
                logits,
                tree_logits,
                new_token,
                past_key_values_data,
                current_length_data,
                self,
                hidden_state,
                hidden_state_new,
                sample_p
            )
            yield input_ids
            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > 1024:
                break
            if input_ids.shape[1] > 1960:
                break
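
    # Streaming usage sketch: ea_generate is a generator that yields the full
    # sequence after every accepted draft step.
    #   for output_ids in model.ea_generate(prompt_ids, temperature=0.0):
    #       print(tok.decode(output_ids[0], skip_special_tokens=True))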
    def naive_generate(
            self,
            input_ids,
            temperature=0.0,
            top_p=0.0,
            top_k=0.0,
            max_steps=512,
            tree_choices=mc_sim_7b_63,
    ):
        if temperature > 1e-5:
            logits_processor = prepare_logits_processor(temperature=temperature, top_p=top_p, top_k=top_k)
        else:
            logits_processor = None
        assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
        # Avoid modifying the input_ids in-place
        input_ids = input_ids.clone()
        self.ea_layer.reset_kv()

        if hasattr(self, "tree_choices") and self.tree_choices == tree_choices:
            tree_buffers = self.tree_buffers
        else:
            tree_buffers = generate_tree_buffers(
                tree_choices, device=self.base_model.model.layers[-1].self_attn.q_proj.weight.device
            )
            self.tree_buffers = tree_buffers
            self.tree_choices = tree_choices

        # Initialize the past key and value states
        if hasattr(self, "past_key_values"):
            past_key_values = self.past_key_values
            past_key_values_data = self.past_key_values_data
            current_length_data = self.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(self.base_model)
            self.past_key_values = past_key_values
            self.past_key_values_data = past_key_values_data
            self.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_tree_mode(self)
        outputs = self.base_model(input_ids, past_key_values=past_key_values, use_cache=True)
        new_token = 0

        for idx in range(max_steps):
            # Plain greedy decoding: one token per base-model forward pass
            input_id = outputs.logits[:, -1:].argmax(dim=-1)
            outputs = self.base_model(input_id, use_cache=True, past_key_values=past_key_values)
            input_ids = torch.cat([input_ids, input_id], dim=-1)
            new_token += 1  # count generated tokens so the cap below can fire
            yield input_ids
            if self.tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
                break
            if new_token > 1024:
                break
            if input_ids.shape[1] > 1960:
                break
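
    # Baseline usage sketch: naive_generate decodes one token per forward pass
    # (no draft tree), making it a convenient speed baseline against eagenerate.
    #   for output_ids in model.naive_generate(prompt_ids):
    #       pass
    #   print(tok.decode(output_ids[0], skip_special_tokens=True))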