import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

# Smart import: works both locally and when the file is loaded from the Hugging Face Hub
try:
    from .configuration_limon import LimonConfig
except ImportError:
    from configuration_limon import LimonConfig


class TimeConditionedAttention(nn.Module):
    """Multi-head causal self-attention. The time conditioning itself is applied
    upstream, via FiLM modulation of the input (see VectorFieldV2)."""

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_size // config.num_heads
        self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        # Causal mask: each position attends only to itself and earlier positions
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        scores.masked_fill_(mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.o_proj(out)


class VectorFieldV2(nn.Module):
    """The right-hand side dx/dt of the latent ODE: an attention + MLP block whose
    affine-free LayerNorms are modulated by time-dependent FiLM scale/shift pairs."""

    def __init__(self, config):
        super().__init__()
        self.anchor_strength = getattr(config, "anchor_strength", 0.1)
        self.ln1 = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
        self.attn = TimeConditionedAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )
        # Maps the scalar time t to two (gamma, beta) FiLM pairs, one per sub-block
        self.time_mlp = nn.Sequential(
            nn.Linear(1, config.hidden_size),
            nn.SiLU(),
            nn.Linear(config.hidden_size, config.hidden_size * 4),
        )

    def forward(self, x, t, x0):
        t_tensor = torch.tensor([t], dtype=x.dtype, device=x.device).view(1, 1, 1)
        time_params = self.time_mlp(t_tensor)
        gamma_1, beta_1, gamma_2, beta_2 = time_params.chunk(4, dim=-1)
        # Anchor the state to the initial embeddings x0 so the trajectory cannot drift freely
        x_anchored = x + self.anchor_strength * x0
        x_mod1 = self.ln1(x_anchored) * (1 + gamma_1) + beta_1
        dx_attn = self.attn(x_mod1)
        x_mod2 = self.ln2(x + dx_attn) * (1 + gamma_2) + beta_2
        dx_mlp = self.mlp(x_mod2)
        return dx_attn + dx_mlp


class ODESolverV2(nn.Module):
    """Integrates the vector field from t = 0 to t = 1 with fixed-step explicit Euler."""

    def __init__(self, vector_field, steps):
        super().__init__()
        self.vector_field = vector_field
        self.steps = steps

    def forward(self, x):
        dt = 1.0 / self.steps
        t = 0.0
        x0 = x.clone()
        for _ in range(self.steps):
            # Euler update: x <- x + f(x, t, x0) * dt
            x = x + self.vector_field(x, t, x0) * dt
            t += dt
        return x


class LimonFlowV1Model(PreTrainedModel):
    config_class = LimonConfig
    # Hard-disable any attempt by HF to create a DynamicCache for this model
    _supports_cache_class = False

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        max_pos = getattr(config, "max_position_embeddings", getattr(config, "max_seq_len", 256))
        self.pos_embeddings = nn.Embedding(max_pos, config.hidden_size)
        steps = getattr(config, "integration_steps", 6)
        self.ode_solver = ODESolverV2(VectorFieldV2(config), steps)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # HACK to satisfy HF internal checks that expect a layered transformer
        self.config.num_hidden_layers = 1
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,  # accepted for HF API compatibility; attention is always full-sequence causal
        inputs_embeds=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            batch_size, seq_len = input_ids.shape
            device = input_ids.device
            x = self.embeddings(input_ids)
        elif inputs_embeds is not None:
            batch_size, seq_len, _ = inputs_embeds.shape
            device = inputs_embeds.device
            x = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        pos = torch.arange(seq_len, device=device).unsqueeze(0)
        x = x + self.pos_embeddings(pos)
        x = self.ode_solver(x)
        logits = self.head(x)

        loss = None
        if labels is not None:
            # Standard next-token objective: shift logits left and labels right
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        # Use the ...WithPast output class so HF generation utilities accept the result
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update({
            "attention_mask": attention_mask,
            # No KV cache: the full sequence is re-integrated at every generation step
            "use_cache": False,
        })
        return model_inputs
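

# Minimal smoke test (a sketch, not part of the model code): it assumes LimonConfig
# accepts the field names used above (vocab_size, hidden_size, num_heads,
# intermediate_size, integration_steps, max_position_embeddings) as keyword
# arguments; adjust to the actual configuration_limon definition.
if __name__ == "__main__":
    config = LimonConfig(
        vocab_size=256,
        hidden_size=64,
        num_heads=4,
        intermediate_size=128,
        integration_steps=4,
        max_position_embeddings=32,
    )
    model = LimonFlowV1Model(config)
    ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=ids, labels=ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 256])
    print(out.loss)          # scalar next-token cross-entropy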