| import torch |
| from torch import nn |
|
|
| from bio_llm.model.candidate_retrieval import CandidateRetriever |
| from bio_llm.model.embedding import TokenEmbedding |
| from bio_llm.model.energy_head import EnergyHead |
| from bio_llm.model.laminar_layer import LaminarRefinement |
| from bio_llm.model.sparse_attention import SparseAttention |
| from bio_llm.utils.config import SSETConfig |
|
|
|
|
| class StructuredSparseEnergyTransformer(nn.Module): |
| """CPU-oriented language model with sparse interaction and energy decoding.""" |
|
|
| def __init__(self, config: SSETConfig): |
| super().__init__() |
| self.config = config |
| self.embedding = TokenEmbedding(config.vocab_size, config.d_model) |
| self.sparse_attention = SparseAttention( |
| d_model=config.d_model, |
| rank=config.low_rank, |
| max_seq_len=config.max_seq_len, |
| top_k=config.attention_top_k, |
| local_window=config.local_window, |
| memory_candidates=config.memory_candidates, |
| landmark_count=config.landmark_count, |
| content_memory_candidates=config.content_memory_candidates, |
| ) |
| self.laminar = LaminarRefinement(steps=config.laminar_steps, eta=config.laminar_eta) |
| self.retriever = CandidateRetriever( |
| vocab_size=config.vocab_size, |
| d_model=config.d_model, |
| stage1_dim=config.stage1_dim, |
| stage1_k=config.retrieval_stage1_k, |
| stage2_k=config.retrieval_stage2_k, |
| ) |
| self.energy_head = EnergyHead( |
| vocab_size=config.vocab_size, |
| d_model=config.d_model, |
| transition_rank=config.transition_rank, |
| ) |
|
|
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| target_ids: torch.Tensor | None = None, |
| attention_mode: str | None = None, |
| ) -> dict[str, torch.Tensor]: |
| embedded = self.embedding(input_ids) |
| attention_state = self.sparse_attention(embedded, mode=attention_mode or self.config.attention_mode) |
| refined_states = self.laminar( |
| states=attention_state.context, |
| attention_indices=attention_state.attention_indices, |
| attention_weights=attention_state.attention_weights, |
| ) |
| candidate_ids = self.retriever( |
| hidden_states=refined_states, |
| embedding_weight=self.embedding.weight, |
| target_ids=target_ids, |
| ) |
| energy_outputs = self.energy_head( |
| hidden_states=refined_states, |
| input_ids=input_ids, |
| prev_tokens=input_ids, |
| candidate_ids=candidate_ids, |
| attention_indices=attention_state.attention_indices, |
| attention_weights=attention_state.attention_weights, |
| embedding_weight=self.embedding.weight, |
| ) |
| return { |
| "hidden_states": refined_states, |
| "candidate_ids": candidate_ids, |
| "attention_indices": attention_state.attention_indices, |
| "attention_weights": attention_state.attention_weights, |
| "confidence": attention_state.confidence, |
| **energy_outputs, |
| } |
|
|
| @torch.no_grad() |
| def generate( |
| self, |
| prompt_ids: list[int], |
| eos_id: int, |
| max_new_tokens: int = 12, |
| attention_mode: str | None = None, |
| ) -> list[int]: |
| generated = list(prompt_ids) |
| for _ in range(max_new_tokens): |
| window = generated[-self.config.max_seq_len :] |
| input_ids = torch.tensor([window], dtype=torch.long) |
| outputs = self.forward(input_ids, attention_mode=attention_mode) |
| last_probabilities = outputs["probabilities"][0, -1] |
| last_candidates = outputs["candidate_ids"][0, -1] |
| next_token = int(last_candidates[last_probabilities.argmax()].item()) |
| generated.append(next_token) |
| if next_token == eos_id: |
| break |
| return generated |
|
|