import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

# The config import works both when loaded as a package module and when run
# as a standalone script.
try:
    from .configuration_limon import LimonConfig
except ImportError:
    from configuration_limon import LimonConfig

class TimeConditionedAttention(nn.Module):
    """Causal multi-head self-attention.

    The module itself takes no time argument: time conditioning is applied
    upstream, where VectorFieldV2 modulates the normalized input with
    time-dependent scale and shift terms before calling this block.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // config.num_heads
        self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with an upper-triangular (causal) mask.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        scores.masked_fill_(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o_proj(out)

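# VectorFieldV2 parameterizes the right-hand side of the flow over hidden
# states, dx/dt = f(x, t, x0), where x0 is the initial embedded sequence.
# The scalar time t is mapped to four modulation tensors that scale and
# shift the output of each parameter-free LayerNorm (FiLM-style
# conditioning), and the anchor term nudges the trajectory back toward x0.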
class VectorFieldV2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.anchor_strength = getattr(config, "anchor_strength", 0.1)
        self.ln1 = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
        self.attn = TimeConditionedAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )
        # Maps the scalar t to (gamma_1, beta_1, gamma_2, beta_2), each of
        # size hidden_size.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, config.hidden_size),
            nn.SiLU(),
            nn.Linear(config.hidden_size, config.hidden_size * 4),
        )

    def forward(self, x, t, x0):
        t_tensor = torch.tensor([t], dtype=x.dtype, device=x.device).view(1, 1, 1)
        time_params = self.time_mlp(t_tensor)
        gamma_1, beta_1, gamma_2, beta_2 = time_params.chunk(4, dim=-1)
        # Pull the state toward the initial embeddings before normalizing.
        x_anchored = x + self.anchor_strength * x0
        x_mod1 = self.ln1(x_anchored) * (1 + gamma_1) + beta_1
        dx_attn = self.attn(x_mod1)
        x_mod2 = self.ln2(x + dx_attn) * (1 + gamma_2) + beta_2
        dx_mlp = self.mlp(x_mod2)
        return dx_attn + dx_mlp

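# ODESolverV2 integrates the vector field with the forward Euler method:
# x_{k+1} = x_k + f(x_k, t_k, x0) * dt, with dt = 1 / steps and t_k = k * dt,
# carrying the hidden states from t = 0 to t = 1 in a fixed number of steps.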
class ODESolverV2(nn.Module):
    def __init__(self, vector_field, steps):
        super().__init__()
        self.vector_field = vector_field
        self.steps = steps

    def forward(self, x):
        dt = 1.0 / self.steps
        t = 0.0
        x0 = x.clone()
        for _ in range(self.steps):
            x = x + self.vector_field(x, t, x0) * dt
            t += dt
        return x

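# LimonFlowV1Model is a causal language model in which the usual stack of
# discrete transformer layers is replaced by a single time-conditioned block
# integrated over continuous time: effective depth comes from the number of
# Euler steps rather than from distinct layers.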
class LimonFlowV1Model(PreTrainedModel):
    config_class = LimonConfig

    # No KV cache is implemented, so the cache API is disabled.
    _supports_cache_class = False

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

        max_pos = getattr(config, "max_position_embeddings", getattr(config, "max_seq_len", 256))
        self.pos_embeddings = nn.Embedding(max_pos, config.hidden_size)

        steps = getattr(config, "integration_steps", 6)
        self.ode_solver = ODESolverV2(VectorFieldV2(config), steps)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # A single shared vector field plays the role of every layer, so
        # report one hidden layer for bookkeeping.
        self.config.num_hidden_layers = 1

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        # NOTE: attention_mask is accepted for API compatibility but is not
        # applied; attention only uses its internal causal mask.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            batch_size, seq_len = input_ids.shape
            device = input_ids.device
            x = self.embeddings(input_ids)
        elif inputs_embeds is not None:
            batch_size, seq_len, _ = inputs_embeds.shape
            device = inputs_embeds.device
            x = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Learned absolute position embeddings, then the ODE integration.
        pos = torch.arange(seq_len, device=device).unsqueeze(0)
        x = x + self.pos_embeddings(pos)

        x = self.ode_solver(x)
        logits = self.head(x)

        loss = None
        if labels is not None:
            # Next-token objective: position i predicts token i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
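
    # Without a cache, generate() re-runs the full sequence at every decoding
    # step, so use_cache is pinned to False below.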
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "attention_mask": attention_mask,
            "use_cache": False,
        })
        return model_inputs
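

if __name__ == "__main__":
    # Minimal smoke test, not part of the model. It assumes LimonConfig
    # accepts the fields read above (vocab_size, hidden_size, num_heads,
    # intermediate_size, integration_steps, max_position_embeddings) as
    # keyword arguments; adjust if the actual config signature differs.
    config = LimonConfig(
        vocab_size=1000,
        hidden_size=64,
        num_heads=4,
        intermediate_size=256,
        integration_steps=4,
        max_position_embeddings=128,
    )
    model = LimonFlowV1Model(config)
    input_ids = torch.randint(0, 1000, (2, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])
    print(out.loss)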