import torch
from torch import nn

from bio_llm.model.candidate_retrieval import CandidateRetriever
from bio_llm.model.embedding import TokenEmbedding
from bio_llm.model.energy_head import EnergyHead
from bio_llm.model.laminar_layer import LaminarRefinement
from bio_llm.model.sparse_attention import SparseAttention
from bio_llm.utils.config import SSETConfig


class StructuredSparseEnergyTransformer(nn.Module):
    """CPU-oriented language model with sparse interaction and energy decoding.

    The forward pass chains five submodules in a fixed order: token
    embedding -> sparse attention -> laminar refinement -> candidate
    retrieval -> energy head. All hyperparameters are read from the
    ``SSETConfig`` instance passed to the constructor, which is kept on
    ``self.config``.
    """

    def __init__(self, config: SSETConfig):
        """Build every submodule from the fields of ``config``.

        Args:
            config: Hyperparameter container. It is stored on the instance
                because ``forward`` reads ``attention_mode`` from it and
                ``generate`` reads ``max_seq_len`` from it.
        """
        super().__init__()
        self.config = config

        self.embedding = TokenEmbedding(config.vocab_size, config.d_model)
        self.sparse_attention = SparseAttention(
            d_model=config.d_model,
            rank=config.low_rank,
            max_seq_len=config.max_seq_len,
            top_k=config.attention_top_k,
            local_window=config.local_window,
            memory_candidates=config.memory_candidates,
            landmark_count=config.landmark_count,
            content_memory_candidates=config.content_memory_candidates,
        )
        self.laminar = LaminarRefinement(
            steps=config.laminar_steps,
            eta=config.laminar_eta,
        )
        self.retriever = CandidateRetriever(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            stage1_dim=config.stage1_dim,
            stage1_k=config.retrieval_stage1_k,
            stage2_k=config.retrieval_stage2_k,
        )
        self.energy_head = EnergyHead(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            transition_rank=config.transition_rank,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        target_ids: torch.Tensor | None = None,
        attention_mode: str | None = None,
    ) -> dict[str, torch.Tensor]:
        """Run the full pipeline on a batch of token ids.

        Args:
            input_ids: Token ids fed to the embedding layer. ``generate``
                passes a ``(1, seq_len)`` long tensor; other shapes are
                presumably accepted as ``(batch, seq_len)`` -- TODO confirm
                against the submodules.
            target_ids: Optional tensor forwarded unchanged to the candidate
                retriever; it is not used anywhere else in this method.
            attention_mode: Mode string handed to the sparse attention
                module. A falsy value (``None`` or ``""``) falls back to
                ``self.config.attention_mode``.

        Returns:
            A dict with the keys ``"hidden_states"``, ``"candidate_ids"``,
            ``"attention_indices"``, ``"attention_weights"`` and
            ``"confidence"``, merged with every entry returned by the energy
            head. Energy-head entries are unpacked last, so they override a
            fixed key of the same name. ``generate`` expects the energy head
            to contribute a ``"probabilities"`` entry.
        """
        embedded = self.embedding(input_ids)
        attention_state = self.sparse_attention(
            embedded,
            mode=attention_mode or self.config.attention_mode,
        )
        refined_states = self.laminar(
            states=attention_state.context,
            attention_indices=attention_state.attention_indices,
            attention_weights=attention_state.attention_weights,
        )
        candidate_ids = self.retriever(
            hidden_states=refined_states,
            embedding_weight=self.embedding.weight,
            target_ids=target_ids,
        )
        # NOTE(review): ``prev_tokens`` receives the same tensor as
        # ``input_ids``; confirm that the energy head does not expect a
        # shifted sequence here.
        energy_outputs = self.energy_head(
            hidden_states=refined_states,
            input_ids=input_ids,
            prev_tokens=input_ids,
            candidate_ids=candidate_ids,
            attention_indices=attention_state.attention_indices,
            attention_weights=attention_state.attention_weights,
            embedding_weight=self.embedding.weight,
        )
        return {
            "hidden_states": refined_states,
            "candidate_ids": candidate_ids,
            "attention_indices": attention_state.attention_indices,
            "attention_weights": attention_state.attention_weights,
            "confidence": attention_state.confidence,
            **energy_outputs,
        }

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: list[int],
        eos_id: int,
        max_new_tokens: int = 12,
        attention_mode: str | None = None,
    ) -> list[int]:
        """Greedily extend ``prompt_ids`` one token at a time.

        Each step feeds the most recent ``config.max_seq_len`` tokens through
        the model, takes the arg-max of the probabilities at the last
        position, and maps that index to a token id through the retrieved
        candidate list for that position.

        Args:
            prompt_ids: Token ids to start from. The caller's list is copied,
                not mutated.
            eos_id: Token id that stops generation. It is appended to the
                result before the loop exits.
            max_new_tokens: Upper bound on the number of appended tokens.
            attention_mode: Forwarded to the forward pass on every step.

        Returns:
            The prompt followed by the generated token ids.

        Raises:
            ValueError: If ``prompt_ids`` is empty while at least one
                decoding step is requested; there is then no last position
                to read a prediction from.
        """
        generated = list(prompt_ids)
        if max_new_tokens > 0 and not generated:
            raise ValueError("prompt_ids must contain at least one token")

        # Build inputs on the device the model lives on so generation keeps
        # working after ``model.to(device)``. ``device=None`` means the
        # default device, which matches the previous behaviour.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else None

        for _ in range(max_new_tokens):
            # Only the trailing ``max_seq_len`` tokens are fed to the model.
            window = generated[-self.config.max_seq_len :]
            input_ids = torch.tensor([window], dtype=torch.long, device=device)

            # Call the module itself rather than ``forward`` so that any
            # registered forward hooks are executed.
            outputs = self(input_ids, attention_mode=attention_mode)

            # Batch element 0, last position. The arg-max is an index into
            # the candidate list for that position, not a vocabulary id.
            last_probabilities = outputs["probabilities"][0, -1]
            last_candidates = outputs["candidate_ids"][0, -1]
            next_token = int(last_candidates[last_probabilities.argmax()].item())

            generated.append(next_token)
            if next_token == eos_id:
                break
        return generated