"""
Modeling class for TaoNet model.
"""
import torch
import torch.nn as nn
from dataclasses import dataclass
from transformers import PreTrainedModel
from .model import SimpleLLM
from .configuration_taonet import TaoNetConfig


@dataclass
class InternalModelConfig:
    """Internal config for SimpleLLM."""

    vocab_size: int = 25000
    d_model: int = 512
    d_embed_rank: int = 384
    d_state: int = 512
    d_ff: int = 512
    n_heads: int = 4
    d_kv_comp: int = 384
    d_rope: int = 64
    n_layers: int = 8
    max_seq_len: int = 256
    dropout: float = 0.02
    block_arrangement: str = "layered"
    ssm_per_mla: int = 3
    layered_mla_num: int = 0
    pad_token_id: int = 3
    bos_token_id: int = 1
    eos_token_id: int = 2
    unk_token_id: int = 0
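
# Example (illustrative): constructing the internal config directly, e.g. for a
# quick smoke test of SimpleLLM without the HF wrapper. The defaults above
# mirror the pico-T1 sizes, so only overrides need to be passed.
#
#     cfg = InternalModelConfig(vocab_size=32000, n_layers=4)
#     model = SimpleLLM(cfg)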


class TaoNetForCausalLM(PreTrainedModel):
    """TaoNet model for causal language modeling."""

    config_class = TaoNetConfig
    base_model_prefix = "taonet"

    def __init__(self, config: TaoNetConfig):
        super().__init__(config)
        # Convert the HF config to the internal config consumed by SimpleLLM.
        internal_config = InternalModelConfig(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            d_embed_rank=config.d_embed_rank,
            d_state=config.d_state,
            d_ff=config.d_ff,
            n_heads=config.n_heads,
            d_kv_comp=config.d_kv_comp,
            d_rope=config.d_rope,
            n_layers=config.n_layers,
            max_seq_len=config.max_seq_len,
            dropout=config.dropout,
            block_arrangement=config.block_arrangement,
            ssm_per_mla=config.ssm_per_mla,
            layered_mla_num=config.layered_mla_num,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            unk_token_id=config.unk_token_id,
        )
        self.taonet = SimpleLLM(internal_config)
        # Tie the lm_head weights to the token embedding weights.
        self._tie_weights()

    def _tie_weights(self):
        """Tie the lm_head weight to the token embedding weight."""
        if hasattr(self.taonet, 'token_embedding') and hasattr(self.taonet, 'lm_head'):
            # Make lm_head.weight reference the same tensor as
            # token_embedding.embed.weight so the two share storage.
            self.taonet.lm_head.weight = self.taonet.token_embedding.embed.weight

    def _init_weights(self, module):
        """Initialize weights (overridden to keep the tied weights in sync)."""
        # Let the parent handle initialization, then re-tie the weights.
        if hasattr(super(), "_init_weights"):
            super()._init_weights(module)
        self._tie_weights()

    @property
    def all_tied_weights_keys(self):
        """Return the tied weights keys to satisfy transformers requirements."""
        # Mapping of tied weight -> main weight.
        return {"taonet.lm_head.weight": "taonet.token_embedding.embed.weight"}

    def mark_tied_weights_as_initialized(self):
        """Mark tied weights as initialized by actually tying them together."""
        # Re-tie so both weights reference the same tensor.
        self._tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Forward pass returning logits and, if labels are given, the LM loss."""
        logits = self.taonet(input_ids)

        loss = None
        if labels is not None:
            # Standard causal LM shift: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        return {
            "loss": loss,
            "logits": logits,
        }
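
    # Example (illustrative): a training-style forward pass. Feeding input_ids
    # as labels works because forward() applies the causal shift internally;
    # positions labeled -100 are ignored by the loss.
    #
    #     out = model(input_ids=batch, labels=batch)
    #     out["loss"].backward()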

    def init_ssm_states(self, batch_size: int, device: torch.device, dtype: torch.dtype):
        """Initialize SSM states for all SSM blocks."""
        return self.taonet.init_ssm_states(batch_size, device, dtype)
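
    # Example (illustrative): manual RNN-style decoding with explicit state,
    # mirroring what generate() does internally. Assumes inference_step takes a
    # (batch, 1) token tensor and returns (logits, updated_states).
    #
    #     states = model.init_ssm_states(1, device, torch.float32)
    #     for t in prompt_ids.split(1, dim=1):
    #         logits, states = model.taonet.inference_step(t, states)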

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        top_k=None,
        top_p=None,
        **kwargs,
    ):
        """Generate text using RNN-style inference with explicit state management."""
        self.taonet.eval()
        batch_size = input_ids.shape[0]
        device = input_ids.device
        dtype = next(self.taonet.parameters()).dtype

        # Initialize SSM states for all SSM blocks.
        states = self.taonet.init_ssm_states(batch_size, device, dtype)
        current_ids = input_ids.clone()

        with torch.no_grad():
            # Prime the states on all prompt tokens except the last one; the
            # generation loop below feeds the last token itself, so priming it
            # here too would process it twice.
            for i in range(input_ids.shape[1] - 1):
                token_id = input_ids[:, i:i + 1]
                _, states = self.taonet.inference_step(token_id, states)

            # Generate new tokens one step at a time.
            for _ in range(max_length - input_ids.shape[1]):
                next_token_id = current_ids[:, -1:]
                logits, states = self.taonet.inference_step(next_token_id, states)
                # Collapse a possible singleton time dimension so the sampling
                # ops below see (batch, vocab_size).
                if logits.dim() == 3:
                    logits = logits[:, -1, :]
                next_logits = logits / temperature

                if top_k is not None:
                    # Keep only the top_k highest-scoring tokens.
                    top_k_logits, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)), dim=-1)
                    indices_to_remove = next_logits < top_k_logits[..., -1, None]
                    next_logits[indices_to_remove] = float('-inf')

                if top_p is not None:
                    # Nucleus sampling: keep the smallest set of tokens whose
                    # cumulative probability exceeds top_p.
                    sorted_logits, sorted_indices = torch.sort(next_logits, descending=True, dim=-1)
                    cumsum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumsum_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    # Map the mask back to the original (unsorted) vocabulary order.
                    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
                    next_logits[indices_to_remove] = float('-inf')

                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                current_ids = torch.cat([current_ids, next_token], dim=1)

                # Stop as soon as any sequence emits EOS (note: this halts the
                # whole batch, not just the finished sequence).
                if (next_token == self.config.eos_token_id).any():
                    break

        return current_ids
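

if __name__ == "__main__":
    # Minimal smoke test (illustrative). The repo id below is a placeholder for
    # wherever this model is published; trust_remote_code is required because
    # the architecture lives in this file rather than in transformers itself.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "Lobakkang/TaoNet-pico-T1"  # placeholder repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    prompt_ids = tokenizer("Hello", return_tensors="pt").input_ids
    output_ids = model.generate(prompt_ids, max_length=50, temperature=0.8, top_k=50)
    print(tokenizer.decode(output_ids[0]))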