"""Prisma model for HuggingFace integration.

Usage:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("y3i12/Prisma", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("y3i12/Prisma")
"""
|
| |
|
import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_prisma import PrismaConfig
from .layers import build_word_start_table, compute_word_positions
from .mirrored import MirroredConfig, MirroredTransformer
|
| |
|
| |
|
class PrismaForCausalLM(PreTrainedModel):
    """Prisma mirrored transformer for causal language modeling.

    Thin HuggingFace adapter around ``MirroredTransformer``: it translates the
    HF-facing ``PrismaConfig`` into the internal ``MirroredConfig``, bridges
    the cache formats, and (when ``word_rope_dims > 0``) maintains the
    per-token word positions consumed by the word-RoPE layers.
    """

    config_class = PrismaConfig
    _tied_weights_keys = ["transformer.lm_head.weight"]
    _no_split_modules = ["MirroredBlock", "MiddleBlock"]
    # Rotary inverse-frequency buffers are derived quantities, so it is fine
    # for them to be absent from a checkpoint.
    _keys_to_ignore_on_load_missing = [
        r"transformer\..*\.rotary\.inv_freq",
        r"transformer\..*\.word_rope\.word_inv_freq",
    ]

    def __init__(self, config: PrismaConfig):
        """Build the inner MirroredTransformer and auxiliary state."""
        super().__init__(config)

        # Mirror every relevant field of the HF config into the internal one.
        mirrored_config = MirroredConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            num_layers=config.num_layers,
            n_middle=config.n_middle,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            aux_skip_k=config.aux_skip_k,
            aux_skip_weight=config.aux_skip_weight,
            use_g2lu=config.use_g2lu,
            word_rope_dims=config.word_rope_dims,
            word_rope_base=config.word_rope_base,
            embed_dim=config.embed_dim,
            head_dim=config.head_dim,
        )
        self.transformer = MirroredTransformer(mirrored_config)

        # Persistent per-token-id table marking tokens that start a new word.
        # Filled by set_tokenizer() or restored from a checkpoint; stays all
        # False until then.
        if config.word_rope_dims > 0:
            self.register_buffer(
                "word_start_table",
                torch.zeros(config.vocab_size, dtype=torch.bool),
                persistent=True,
            )
        else:
            self.word_start_table = None

        # Running intra-word position used during incremental decoding.
        # NOTE(review): this is a single model-level counter that tracks only
        # batch row 0, so cached generation with word-RoPE effectively
        # assumes batch size 1 — confirm with callers.
        self._word_pos_counter = 0

        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Build word_start_table from tokenizer. Call this if not loading from pretrained."""
        if self.config.word_rope_dims > 0:
            table = build_word_start_table(tokenizer, self.config.vocab_size)
            self.word_start_table = table.to(self.device)

    def get_input_embeddings(self):
        """Return the token-embedding module (HF accessor contract)."""
        return self.transformer.embed

    def set_input_embeddings(self, value):
        """Replace the token-embedding module (HF accessor contract)."""
        self.transformer.embed = value

    def get_output_embeddings(self):
        """Return the LM head (HF accessor contract)."""
        return self.transformer.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (HF accessor contract)."""
        self.transformer.lm_head = new_embeddings

    def tie_weights(self):
        """Share embed/lm_head weights when the config asks for tying.

        Tying is only possible when the embedding width matches the head
        width; otherwise the request is silently skipped.
        """
        if self.config.tie_word_embeddings:
            embed_dim = self.config.embed_dim or self.config.hidden_size
            head_dim = self.config.head_dim or self.config.hidden_size
            if embed_dim == head_dim:
                self.transformer.lm_head.weight = self.transformer.embed.weight

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        labels=None,
        use_cache=False,
        return_dict=True,
        **kwargs,
    ):
        """Run the mirrored transformer under the HF forward contract.

        Args:
            input_ids: token ids, shape (batch, seq).
            attention_mask: accepted for API compatibility but NOT forwarded
                to the inner transformer (it builds its own causal mask).
            past_key_values: HF ``Cache`` object or legacy list of
                (key, value) tuples from a previous step.
            labels: optional targets; the inner transformer computes the loss.
            use_cache: whether to return an updated KV cache.
            return_dict: return ``CausalLMOutputWithPast`` vs a plain tuple.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is True, else a
            tuple ``(loss?, logits, past_key_values?)`` per HF convention.
        """
        # Normalize the incoming cache to the list-of-tuples format the
        # inner transformer expects; treat empty caches as "no cache".
        past_kv_list = None
        if past_key_values is not None:
            has_content = False
            if isinstance(past_key_values, (list, tuple)):
                has_content = len(past_key_values) > 0
                past_kv_list = past_key_values if has_content else None
            elif hasattr(past_key_values, 'get_seq_length'):
                # HF Cache object (e.g. DynamicCache); index access yields
                # per-layer (key, value) pairs.
                has_content = past_key_values.get_seq_length() > 0
                if has_content:
                    past_kv_list = [past_key_values[i] for i in range(len(past_key_values))]

        # Word positions for word-RoPE: position of each token within its
        # word (0 at a word start, then counting up).
        word_positions = None
        if self.word_start_table is not None and self.config.word_rope_dims > 0:
            if past_kv_list is not None and input_ids.size(1) == 1:
                # Incremental decode: advance/reset the running counter based
                # on whether the single new token starts a word.
                # NOTE(review): only row 0 of the batch is inspected.
                last_token = input_ids[0, -1].item()
                if self.word_start_table[last_token]:
                    self._word_pos_counter = 0
                else:
                    self._word_pos_counter += 1
                word_positions = torch.tensor(
                    [[float(self._word_pos_counter)]],
                    device=input_ids.device,
                )
            else:
                # Full (prefill) pass: compute positions for every token and
                # seed the counter from the last token so decoding can resume.
                word_positions = compute_word_positions(input_ids, self.word_start_table)
                self._word_pos_counter = int(word_positions[0, -1].item())

        output = self.transformer(
            input_ids,
            labels=labels,
            use_cache=use_cache,
            past_kv=past_kv_list,
            word_positions=word_positions,
        )

        # Repackage the list-of-tuples cache as a DynamicCache so HF
        # generation utilities can consume it.
        new_cache = None
        if use_cache and output.get("past_kv") is not None:
            from transformers.cache_utils import DynamicCache
            new_cache = DynamicCache()
            for layer_idx, (k, v) in enumerate(output["past_kv"]):
                new_cache.update(k, v, layer_idx)

        if not return_dict:
            result = (output["logits"],)
            if use_cache:
                result += (new_cache,)
            # FIX: per HF convention the tuple output must lead with the loss
            # when one was computed; previously the loss was dropped here.
            loss = output.get("loss")
            return (loss,) + result if loss is not None else result

        return CausalLMOutputWithPast(
            loss=output.get("loss"),
            logits=output["logits"],
            past_key_values=new_cache,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, **kwargs
    ):
        """Trim input_ids to the last token once a non-empty cache exists."""
        has_cache = False
        if past_key_values is not None:
            if hasattr(past_key_values, 'get_seq_length'):
                has_cache = past_key_values.get_seq_length() > 0
            elif isinstance(past_key_values, (list, tuple)):
                has_cache = len(past_key_values) > 0
        if has_cache:
            # Only the newest token is needed; the rest lives in the cache.
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": True,
        }
|
| |
|