welddy committed
Commit bafa499 · verified · 1 Parent(s): 7e5e67d

Upload modeling_limon.py

Files changed (1)
  1. modeling_limon.py +83 -15
modeling_limon.py CHANGED
@@ -3,7 +3,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PreTrainedModel
-from .configuration_limon import LimonConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+# Smart import: works both locally and on Hugging Face
+try:
+    from .configuration_limon import LimonConfig
+except ImportError:
+    from configuration_limon import LimonConfig
 
 class TimeConditionedAttention(nn.Module):
     def __init__(self, config):
@@ -30,7 +36,7 @@ class TimeConditionedAttention(nn.Module):
 class VectorFieldV2(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.anchor_strength = config.anchor_strength
+        self.anchor_strength = getattr(config, "anchor_strength", 0.1)
         self.ln1 = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
         self.attn = TimeConditionedAttention(config)
         self.ln2 = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
@@ -72,29 +78,91 @@ class ODESolverV2(nn.Module):
 
 class LimonFlowV1Model(PreTrainedModel):
     config_class = LimonConfig
+
+    # Hard ban on HF's attempts to create a DynamicCache
+    _supports_cache_class = False
+
     def __init__(self, config):
         super().__init__(config)
         self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.pos_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
-        self.ode_solver = ODESolverV2(VectorFieldV2(config), config.integration_steps)
+
+        max_pos = getattr(config, "max_position_embeddings", getattr(config, "max_seq_len", 256))
+        self.pos_embeddings = nn.Embedding(max_pos, config.hidden_size)
+
+        steps = getattr(config, "integration_steps", 6)
+        self.ode_solver = ODESolverV2(VectorFieldV2(config), steps)
         self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # HACK to bypass HF's internal checks
+        self.config.num_hidden_layers = 1
+
         self.post_init()
 
-    def forward(self, input_ids, labels=None, attention_mask=None, **kwargs):
-        batch_size, seq_len = input_ids.shape
-        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
-        x = self.embeddings(input_ids) + self.pos_embeddings(pos)
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        **kwargs
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None:
+            batch_size, seq_len = input_ids.shape
+            device = input_ids.device
+            x = self.embeddings(input_ids)
+        elif inputs_embeds is not None:
+            batch_size, seq_len, _ = inputs_embeds.shape
+            device = inputs_embeds.device
+            x = inputs_embeds
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        pos = torch.arange(seq_len, device=device).unsqueeze(0)
+        x = x + self.pos_embeddings(pos)
+
         x = self.ode_solver(x)
         logits = self.head(x)
 
         loss = None
         if labels is not None:
-            loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        # USE THE CORRECT CLASS (WithPast)
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
 
-        return {
-            "logits": logits,
-            "loss": loss
-        }
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
 
-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        return {"input_ids": input_ids}
+        model_inputs.update({
+            "attention_mask": attention_mask,
+            "use_cache": False,
+        })
+        return model_inputs
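
For context on the config dependency: the modeling code above reads only a handful of attributes from LimonConfig, and the new getattr() fallbacks make most of them optional. configuration_limon.py itself is not part of this commit, so the sketch below is only a guess at its minimal shape; every default value here is an assumption that mirrors the fallbacks used above.

# Hypothetical sketch of configuration_limon.py. NOT the actual file from this
# repo; defaults mirror the getattr() fallbacks in modeling_limon.py above.
from transformers import PretrainedConfig

class LimonConfig(PretrainedConfig):
    model_type = "limon"  # assumption: the real model_type string is not shown

    def __init__(
        self,
        vocab_size=32000,             # assumption: not shown in the diff
        hidden_size=512,              # assumption: not shown in the diff
        max_position_embeddings=256,  # matches the 256 fallback above
        integration_steps=6,          # matches the getattr fallback above
        anchor_strength=0.1,          # matches the getattr fallback above
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.integration_steps = integration_steps
        self.anchor_strength = anchor_strength
        super().__init__(**kwargs)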
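
The body of ODESolverV2 falls outside the diff context, so what integration_steps controls is not shown here; the usual shape for such a module is a fixed-step integrator that pushes the hidden states along the learned vector field from t = 0 to t = 1. A minimal Euler-style sketch, purely illustrative and not this repo's implementation (it assumes the vector field is called as vector_field(x, t)):

import torch.nn as nn

class FixedStepEulerSolver(nn.Module):
    # Illustrative stand-in for ODESolverV2; the real class may differ.
    def __init__(self, vector_field, steps):
        super().__init__()
        self.vector_field = vector_field
        self.steps = steps

    def forward(self, x):
        dt = 1.0 / self.steps
        for i in range(self.steps):
            t = i * dt
            # Euler update: x <- x + f(x, t) * dt
            x = x + dt * self.vector_field(x, t)
        return x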
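
The loss change in forward() is the standard causal-LM shift: the old code scored each position's logits against that same position's label, while the new code scores the logits at position t against the token at position t + 1. A self-contained illustration of the alignment:

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Drop the last position's logits and the first label token, so that
# shift_logits[:, t] is scored against labels[:, t + 1].
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())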
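
With CausalLMOutputWithPast as the return type, prepare_inputs_for_generation defined, and use_cache pinned to False, the class is shaped to drive generate(). A hedged usage sketch: the repo id below is a placeholder (the commit does not name the repository), loading remote code requires trust_remote_code=True plus a matching auto_map entry in the repo's config.json, and on newer transformers releases generate() also requires the model class to inherit GenerationMixin explicitly.

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "welddy/limon"  # placeholder repo id, not named in this commit
tok = AutoTokenizer.from_pretrained(repo)  # assumes the repo ships a tokenizer
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

inputs = tok("Once upon a time", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=32, use_cache=False)
print(tok.decode(out[0], skip_special_tokens=True))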