| """
|
| Modeling class for TaoNet model.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from dataclasses import dataclass
|
| from transformers import PreTrainedModel
|
|
|
| from .model import SimpleLLM
|
| from .configuration_taonet import TaoNetConfig
|
|
|
|
|
@dataclass

class InternalModelConfig:

    """Internal config for SimpleLLM.

    Plain-dataclass mirror of ``TaoNetConfig``: ``TaoNetForCausalLM.__init__``
    copies every field from the HF config into this object before building
    ``SimpleLLM``, so the defaults below act only as fallbacks.
    """

    vocab_size: int = 25000  # tokenizer vocabulary size (rows of the embedding)

    d_model: int = 512  # hidden / residual-stream width

    d_embed_rank: int = 384  # presumably a low-rank embedding factorization dim — TODO confirm in SimpleLLM

    d_state: int = 512  # SSM state size, judging by init_ssm_states usage — confirm

    d_ff: int = 512  # feed-forward (MLP) inner width

    n_heads: int = 4  # attention heads

    d_kv_comp: int = 384  # presumably MLA key/value compression dim — confirm

    d_rope: int = 64  # presumably rotary-embedding dims per head — confirm

    n_layers: int = 8  # number of blocks

    max_seq_len: int = 256  # maximum sequence length the model supports

    dropout: float = 0.02  # dropout probability

    block_arrangement: str = "layered"  # how SSM and MLA blocks are interleaved — see SimpleLLM

    ssm_per_mla: int = 3  # presumably SSM blocks per MLA block — confirm

    layered_mla_num: int = 0  # meaning depends on block_arrangement — confirm in SimpleLLM

    pad_token_id: int = 3  # special token ids (must match the tokenizer)

    bos_token_id: int = 1

    eos_token_id: int = 2

    unk_token_id: int = 0
|
|
|
|
|
class TaoNetForCausalLM(PreTrainedModel):

    """TaoNet model for causal language modeling.

    Thin HuggingFace ``PreTrainedModel`` wrapper around the project-internal
    ``SimpleLLM`` implementation. The output head (``lm_head``) is weight-tied
    to the token embedding.
    """

    config_class = TaoNetConfig

    base_model_prefix = "taonet"

    def __init__(self, config: TaoNetConfig):
        super().__init__(config)

        # Translate the HF config into the plain dataclass SimpleLLM expects.
        internal_config = InternalModelConfig(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            d_embed_rank=config.d_embed_rank,
            d_state=config.d_state,
            d_ff=config.d_ff,
            n_heads=config.n_heads,
            d_kv_comp=config.d_kv_comp,
            d_rope=config.d_rope,
            n_layers=config.n_layers,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            block_arrangement=config.block_arrangement,
            ssm_per_mla=config.ssm_per_mla,
            layered_mla_num=config.layered_mla_num,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            unk_token_id=config.unk_token_id,
        )

        self.taonet = SimpleLLM(internal_config)

        # Share one weight matrix between the embedding and the output head.
        self._tie_weights()

    def _tie_weights(self):
        """Tie the lm_head weight to the token embedding weight."""
        # Guarded with hasattr so this is a no-op if SimpleLLM's internals
        # are ever renamed, instead of crashing during __init__.
        if hasattr(self.taonet, 'token_embedding') and hasattr(self.taonet, 'lm_head'):
            self.taonet.lm_head.weight = self.taonet.token_embedding.embed.weight

    def _init_weights(self, module):
        """Initialize weights, then re-tie so initialization cannot leave the
        embedding and lm_head pointing at different tensors."""
        parent_init = getattr(super(), '_init_weights', None)
        if parent_init is not None:
            parent_init(module)
        self._tie_weights()

    @property
    def all_tied_weights_keys(self):
        """Return the tied weights keys to satisfy transformers requirements."""
        return {"taonet.lm_head.weight": "taonet.token_embedding.embed.weight"}

    def mark_tied_weights_as_initialized(self):
        """Mark tied weights as initialized by actually tying them together."""
        self._tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Forward pass.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask: accepted for HF API compatibility but not used —
                the underlying ``SimpleLLM`` is called with ``input_ids`` only.
            labels: optional ``(batch, seq_len)`` targets; positions equal to
                ``-100`` are ignored by the loss.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` of shape ``(batch, seq_len, vocab_size)``.
        """
        logits = self.taonet(input_ids)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1)
            )

        return {
            "loss": loss,
            "logits": logits,
        }

    def init_ssm_states(self, batch_size: int, device: torch.device, dtype: torch.dtype):
        """Initialize SSM states for all SSM blocks (delegates to SimpleLLM)."""
        return self.taonet.init_ssm_states(batch_size, device, dtype)

    @staticmethod
    def _filter_logits(logits, top_k, top_p):
        """Apply top-k and/or nucleus (top-p) filtering to ``logits``.

        Filtered-out entries are set to ``-inf`` so softmax assigns them
        zero probability. Returns a new tensor; ``logits`` is not mutated.
        """
        if top_k is not None:
            # Keep only entries >= the k-th largest logit per row.
            kth = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth, float('-inf'))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumsum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumsum_probs > top_p
            # Always keep at least the single most likely token.
            sorted_indices_to_remove[..., 0] = False
            # Scatter the sorted-order mask back to vocabulary order.
            # (Fix: the previous code indexed with the *gathered* indices,
            # which mixed rows across the batch and over-masked tokens.)
            indices_to_remove = sorted_indices_to_remove.scatter(
                -1, sorted_indices, sorted_indices_to_remove
            )
            logits = logits.masked_fill(indices_to_remove, float('-inf'))

        return logits

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k=None,
        top_p=None,
        **kwargs,
    ):
        """Generate text using RNN-style inference with state management.

        Args:
            input_ids: ``(batch, prompt_len)`` prompt token ids (non-empty).
            max_length: total sequence length (prompt + generated) to stop at.
            temperature: softmax temperature; must be > 0.
            top_k: keep only the k most likely tokens per step, if given.
            top_p: nucleus sampling threshold in (0, 1], if given.

        Returns:
            ``(batch, <= max_length)`` tensor of prompt + generated ids.
        """
        self.taonet.eval()
        batch_size = input_ids.shape[0]
        device = input_ids.device
        dtype = next(self.taonet.parameters()).dtype

        states = self.taonet.init_ssm_states(batch_size, device, dtype)

        current_ids = input_ids.clone()

        with torch.no_grad():
            # Prefill: advance the recurrent state over the prompt, keeping
            # the logits from the LAST prompt token for the first sampling
            # step. (Fix: previously the last prompt token was fed through
            # inference_step a second time at the start of generation,
            # updating the state with it twice.)
            logits = None
            for i in range(input_ids.shape[1]):
                logits, states = self.taonet.inference_step(input_ids[:, i:i + 1], states)

            for _ in range(max_length - input_ids.shape[1]):
                next_logits = self._filter_logits(logits / temperature, top_k, top_p)

                # assumes inference_step returns (batch, vocab) logits —
                # multinomial requires <= 2 dims; TODO confirm in SimpleLLM.
                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                current_ids = torch.cat([current_ids, next_token], dim=1)

                # NOTE(review): stops the WHOLE batch as soon as any one
                # sequence emits EOS (pre-existing behavior, kept as-is);
                # per-sequence stopping would need a finished mask.
                if (next_token == self.config.eos_token_id).any():
                    break

                # Advance the state with the token just sampled.
                logits, states = self.taonet.inference_step(next_token, states)

        return current_ids
|
|
|