mesko-tts / bio_llm /model /model.py
mesklintech's picture
Publish BioVoice-TTS sparse energy checkpoint and model card
424c56c verified
import torch
from torch import nn
from bio_llm.model.candidate_retrieval import CandidateRetriever
from bio_llm.model.embedding import TokenEmbedding
from bio_llm.model.energy_head import EnergyHead
from bio_llm.model.laminar_layer import LaminarRefinement
from bio_llm.model.sparse_attention import SparseAttention
from bio_llm.utils.config import SSETConfig
class StructuredSparseEnergyTransformer(nn.Module):
"""CPU-oriented language model with sparse interaction and energy decoding."""
def __init__(self, config: SSETConfig):
super().__init__()
self.config = config
self.embedding = TokenEmbedding(config.vocab_size, config.d_model)
self.sparse_attention = SparseAttention(
d_model=config.d_model,
rank=config.low_rank,
max_seq_len=config.max_seq_len,
top_k=config.attention_top_k,
local_window=config.local_window,
memory_candidates=config.memory_candidates,
landmark_count=config.landmark_count,
content_memory_candidates=config.content_memory_candidates,
)
self.laminar = LaminarRefinement(steps=config.laminar_steps, eta=config.laminar_eta)
self.retriever = CandidateRetriever(
vocab_size=config.vocab_size,
d_model=config.d_model,
stage1_dim=config.stage1_dim,
stage1_k=config.retrieval_stage1_k,
stage2_k=config.retrieval_stage2_k,
)
self.energy_head = EnergyHead(
vocab_size=config.vocab_size,
d_model=config.d_model,
transition_rank=config.transition_rank,
)
def forward(
self,
input_ids: torch.Tensor,
target_ids: torch.Tensor | None = None,
attention_mode: str | None = None,
) -> dict[str, torch.Tensor]:
embedded = self.embedding(input_ids)
attention_state = self.sparse_attention(embedded, mode=attention_mode or self.config.attention_mode)
refined_states = self.laminar(
states=attention_state.context,
attention_indices=attention_state.attention_indices,
attention_weights=attention_state.attention_weights,
)
candidate_ids = self.retriever(
hidden_states=refined_states,
embedding_weight=self.embedding.weight,
target_ids=target_ids,
)
energy_outputs = self.energy_head(
hidden_states=refined_states,
input_ids=input_ids,
prev_tokens=input_ids,
candidate_ids=candidate_ids,
attention_indices=attention_state.attention_indices,
attention_weights=attention_state.attention_weights,
embedding_weight=self.embedding.weight,
)
return {
"hidden_states": refined_states,
"candidate_ids": candidate_ids,
"attention_indices": attention_state.attention_indices,
"attention_weights": attention_state.attention_weights,
"confidence": attention_state.confidence,
**energy_outputs,
}
@torch.no_grad()
def generate(
self,
prompt_ids: list[int],
eos_id: int,
max_new_tokens: int = 12,
attention_mode: str | None = None,
) -> list[int]:
generated = list(prompt_ids)
for _ in range(max_new_tokens):
window = generated[-self.config.max_seq_len :]
input_ids = torch.tensor([window], dtype=torch.long)
outputs = self.forward(input_ids, attention_mode=attention_mode)
last_probabilities = outputs["probabilities"][0, -1]
last_candidates = outputs["candidate_ids"][0, -1]
next_token = int(last_candidates[last_probabilities.argmax()].item())
generated.append(next_token)
if next_token == eos_id:
break
return generated