Wolfvin commited on
Commit
46c5bd3
Β·
verified Β·
1 Parent(s): 9c1a00d

Upload diffusion_llm/model/evoformer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_llm/model/evoformer.py +696 -0
diffusion_llm/model/evoformer.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AAM Diffusion LLM β€” Evoformer Feedback System
2
+
3
+ Adapted from Losion/AlphaFold2: iterative bidirectional feedback
4
+ at multiple architecture levels.
5
+
6
+ For AAM, the most relevant levels:
7
+ Level 1 β€” Inter-Layer Recycling: Layer deep ↔ Layer shallow
8
+ Level 2 β€” Bidirectional Token Update: Token old ↔ Token new
9
+ Level 3 β€” Decoder ↔ Predict: Narrative output ↔ Graph conditioning
10
+ Level 4 β€” Prediction β†’ Context: Predicted narrative refines graph understanding
11
+ Level 5 β€” Router-Expert Co-evolution: Graph node ↔ Sentence arrangement
12
+
13
+ Core Principle: "Whenever there are two related representations, replace
14
+ one-way information flow with iterative bidirectional dialogue."
15
+
16
+ This is PERFECT for AAM's Predictive Coding:
17
+ predict(X) β†’ observe(Y) β†’ belief_update(Ξ”)
18
+
19
+ Evoformer makes this bidirectional and iterative.
20
+
21
+ Level 5 (RouterExpertCoevolve) β€” AAM-specific adaptation:
22
+ In Losion, this handles router ↔ MoE expert co-evolution.
23
+ For AAM, this handles: graph node ↔ sentence arrangement co-evolution.
24
+ The co-evolve state captures the "negotiation" between graph
25
+ understanding and narrative output β€” each side adjusts based on
26
+ the other's current state, creating an iterative dialogue where
27
+ better graph understanding leads to better narrative, and better
28
+ narrative feedback refines graph understanding.
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import math
34
+ from dataclasses import dataclass
35
+ from typing import Any, Dict, List, Optional, Tuple
36
+
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.nn.functional as F
40
+
41
+
42
+ @dataclass
43
+ class EvoformerConfig:
44
+ """Configuration for Evoformer Feedback System.
45
+
46
+ Attributes:
47
+ d_model: Model hidden dimension.
48
+ n_recycling_steps: Number of recycling iterations.
49
+ dropout: Dropout rate for all sub-modules.
50
+ use_layer_recycling: Enable Level 1 (inter-layer recycling).
51
+ use_token_recycling: Enable Level 2 (bidirectional token update).
52
+ use_decoder_feedback: Enable Level 3 (decoder-predict feedback).
53
+ use_prediction_recycling: Enable Level 4 (prediction-context recycling).
54
+ use_router_coevolve: Enable Level 5 (router-expert co-evolution).
55
+ d_pair: Pair representation dimension for co-evolution state.
56
+ 0 means use d_model.
57
+ min_recycling_improvement: Minimum improvement threshold for recycling.
58
+ """
59
+
60
+ d_model: int = 768
61
+ n_recycling_steps: int = 3
62
+ dropout: float = 0.0
63
+ use_layer_recycling: bool = True
64
+ use_token_recycling: bool = True
65
+ use_decoder_feedback: bool = True
66
+ use_prediction_recycling: bool = True
67
+ use_router_coevolve: bool = True
68
+ d_pair: int = 0 # 0 = use d_model
69
+ min_recycling_improvement: float = 1e-4
70
+
71
+
72
+ class LayerRecyclingBlock(nn.Module):
73
+ """Level 1: Bidirectional feedback between deep and shallow layers.
74
+
75
+ Losion v1.9.0 gradient-flow fix: deep layers also receive a small
76
+ revision residual (0.05 multiplier) so that ``recycled[-1]`` carries
77
+ gradient through the revision path back to all layer_recycling
78
+ parameters. Without this, deep layers get no revision and the
79
+ gradient from the final output cannot flow back through the
80
+ revision path.
81
+ """
82
+
83
+ def __init__(self, d_model: int, n_recycling_steps: int = 2, dropout: float = 0.0) -> None:
84
+ super().__init__()
85
+ self.d_model = d_model
86
+ self.n_recycling_steps = n_recycling_steps
87
+
88
+ self.shallow_query_proj = nn.Linear(d_model, d_model, bias=False)
89
+ self.deep_key_proj = nn.Linear(d_model, d_model, bias=False)
90
+ self.deep_value_proj = nn.Linear(d_model, d_model, bias=False)
91
+ self.revision_proj = nn.Linear(d_model, d_model, bias=False)
92
+
93
+ self.revision_gate = nn.Sequential(
94
+ nn.Linear(d_model * 2, 1, bias=False),
95
+ nn.Sigmoid(),
96
+ )
97
+
98
+ self.dropout = nn.Dropout(dropout) if dropout > 0 else None
99
+ self.scale = math.sqrt(d_model)
100
+
101
+ # Losion v1.9.0: deep-layer revision multiplier (small but nonzero
102
+ # to maintain gradient flow through the revision path).
103
+ self.deep_revision_multiplier: float = 0.05
104
+
105
+ def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
106
+ if len(hidden_states) < 2:
107
+ return hidden_states
108
+
109
+ n_layers = len(hidden_states)
110
+ mid = n_layers // 2
111
+ shallow_repr = torch.stack(hidden_states[:mid], dim=0).mean(dim=0)
112
+ deep_repr = torch.stack(hidden_states[mid:], dim=0).mean(dim=0)
113
+
114
+ q = self.shallow_query_proj(shallow_repr)
115
+ k = self.deep_key_proj(deep_repr)
116
+ v = self.deep_value_proj(deep_repr)
117
+
118
+ k_mean = k.mean(dim=1, keepdim=True)
119
+ v_mean = v.mean(dim=1, keepdim=True)
120
+
121
+ scores = torch.matmul(q, k_mean.transpose(-2, -1)) / self.scale
122
+ attn = F.softmax(scores, dim=-1)
123
+
124
+ if self.dropout is not None:
125
+ attn = self.dropout(attn)
126
+
127
+ revision = torch.matmul(attn, v_mean)
128
+ revision = self.revision_proj(revision)
129
+
130
+ gate = self.revision_gate(torch.cat([shallow_repr, revision], dim=-1))
131
+ revision = gate * revision
132
+
133
+ revised = []
134
+ for i, h in enumerate(hidden_states):
135
+ if i < mid:
136
+ revised.append(h + revision * (0.1 if i < mid // 2 else 0.2))
137
+ else:
138
+ # Losion v1.9.0 fix: deep layers receive a small revision
139
+ # residual so gradient flows from recycled[-1] back through
140
+ # the revision path to all layer_recycling parameters.
141
+ revised.append(h + revision * self.deep_revision_multiplier)
142
+
143
+ return revised
144
+
145
+
146
+ class BidirectionalTokenUpdate(nn.Module):
147
+ """Level 2: Later tokens revise earlier token representations."""
148
+
149
+ def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.0) -> None:
150
+ super().__init__()
151
+ self.d_model = d_model
152
+ self.n_heads = n_heads
153
+ self.d_kv = d_model // n_heads
154
+
155
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
156
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
157
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
158
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
159
+
160
+ self.gate = nn.Sequential(
161
+ nn.Linear(d_model, 1, bias=False),
162
+ nn.Sigmoid(),
163
+ )
164
+
165
+ self.norm = nn.RMSNorm(d_model)
166
+ self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None
167
+ self.scale = math.sqrt(self.d_kv)
168
+
169
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
170
+ batch, seq_len, _ = x.shape
171
+ if seq_len <= 1:
172
+ return x
173
+
174
+ q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2)
175
+ k = self.k_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2)
176
+ v = self.v_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2)
177
+
178
+ scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
179
+ attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(x.dtype)
180
+
181
+ if self.dropout_mod is not None:
182
+ attn = self.dropout_mod(attn)
183
+
184
+ backward_info = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
185
+ backward_info = self.out_proj(backward_info)
186
+
187
+ gate = self.gate(x)
188
+ revised = x + gate * backward_info
189
+ revised = self.norm(revised)
190
+
191
+ return revised
192
+
193
+
194
+ class DecoderPredictFeedback(nn.Module):
195
+ """Level 3: Bidirectional feedback between decoder output and graph prediction.
196
+
197
+ AAM-specific: narrative output revises graph conditioning.
198
+ Predict v1 β†’ Decoder refine β†’ feedback β†’ Update v1 β†’ loop
199
+ """
200
+
201
+ def __init__(self, d_model: int, n_iterations: int = 2, dropout: float = 0.0) -> None:
202
+ super().__init__()
203
+ self.d_model = d_model
204
+ self.n_iterations = n_iterations
205
+
206
+ self.feedback_proj = nn.Sequential(
207
+ nn.Linear(d_model, d_model, bias=False),
208
+ nn.SiLU(),
209
+ nn.Linear(d_model, d_model, bias=False),
210
+ )
211
+
212
+ self.feedback_gate = nn.Sequential(
213
+ nn.Linear(d_model, 1, bias=False),
214
+ nn.Sigmoid(),
215
+ )
216
+
217
+ self.norm = nn.RMSNorm(d_model)
218
+ self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None
219
+
220
+ def forward(self, hidden_state: torch.Tensor, decoder_output: torch.Tensor) -> torch.Tensor:
221
+ delta = decoder_output - hidden_state
222
+ feedback = self.feedback_proj(delta)
223
+ gate = self.feedback_gate(hidden_state)
224
+ feedback = gate * feedback
225
+
226
+ if self.dropout_mod is not None:
227
+ feedback = self.dropout_mod(feedback)
228
+
229
+ updated = self.norm(hidden_state + feedback)
230
+ return updated
231
+
232
+
233
+ class PredictionContextRecycling(nn.Module):
234
+ """Level 4: Predicted narrative revises graph understanding.
235
+
236
+ AAM-specific: the generated narrative can refine how we understand
237
+ the graph, creating a feedback loop between output and input.
238
+ """
239
+
240
+ def __init__(self, d_model: int, dropout: float = 0.0) -> None:
241
+ super().__init__()
242
+ self.d_model = d_model
243
+
244
+ self.pred_proj = nn.Linear(d_model, d_model, bias=False)
245
+ self.context_query = nn.Linear(d_model, d_model, bias=False)
246
+ self.pred_key = nn.Linear(d_model, d_model, bias=False)
247
+ self.pred_value = nn.Linear(d_model, d_model, bias=False)
248
+ self.revision_proj = nn.Linear(d_model, d_model, bias=False)
249
+ self.revision_gate = nn.Sequential(
250
+ nn.Linear(d_model, 1, bias=False),
251
+ nn.Sigmoid(),
252
+ )
253
+
254
+ self.norm = nn.RMSNorm(d_model)
255
+ self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None
256
+ self.scale = math.sqrt(d_model)
257
+
258
+ def forward(self, hidden_states: torch.Tensor, prediction_logits: torch.Tensor) -> torch.Tensor:
259
+ batch, seq_len, _ = hidden_states.shape
260
+
261
+ if prediction_logits.shape[-1] != self.d_model:
262
+ pred_repr = self.pred_proj(prediction_logits[:, -1:, :self.d_model]
263
+ if prediction_logits.dim() == 3
264
+ else prediction_logits.unsqueeze(1))
265
+ else:
266
+ pred_repr = prediction_logits[:, -1:, :] if prediction_logits.dim() == 3 else prediction_logits.unsqueeze(1)
267
+
268
+ q = self.context_query(hidden_states)
269
+ k = self.pred_key(pred_repr)
270
+ v = self.pred_value(pred_repr)
271
+
272
+ scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
273
+ attn = F.softmax(scores, dim=-2)
274
+
275
+ if self.dropout_mod is not None:
276
+ attn = self.dropout_mod(attn)
277
+
278
+ revision = torch.matmul(attn, v)
279
+ revision = self.revision_proj(revision)
280
+
281
+ gate = self.revision_gate(hidden_states)
282
+ revised = hidden_states + gate * revision
283
+ revised = self.norm(revised)
284
+
285
+ return revised
286
+
287
+
288
+ class RouterExpertCoevolve(nn.Module):
289
+ """Level 5: Graph node ↔ sentence arrangement co-evolution.
290
+
291
+ Adapted from Losion's RouterExpertCoevolve (router ↔ MoE expert
292
+ co-evolution). In Losion, the router distributes tokens to MoE
293
+ experts, and expert outputs refine the router's decisions β€” a
294
+ bidirectional negotiation.
295
+
296
+ For AAM, the co-evolution is between:
297
+ - Graph nodes: evidence from RSVS graph (the "router" side β€”
298
+ which evidence to attend to)
299
+ - Sentence arrangement: narrative output (the "expert" side β€”
300
+ how to express the evidence in natural language)
301
+
302
+ The co-evolve state captures the "negotiation" between graph
303
+ understanding and narrative output: each side adjusts based on
304
+ the other's current state, creating an iterative dialogue where
305
+ better graph understanding leads to better narrative, and better
306
+ narrative feedback refines graph understanding.
307
+
308
+ Key design (from Losion v1.9.0):
309
+ - ``update_state()`` returns a **differentiable** tensor so
310
+ gradient flows through the revision path to all
311
+ RouterExpertCoevolve parameters.
312
+ - The internal buffer is updated with **detached** values to
313
+ prevent unbounded gradient accumulation across training steps.
314
+
315
+ Args:
316
+ d_model: Model hidden dimension.
317
+ d_pair: Pair (co-evolution state) dimension. 0 means use d_model.
318
+ n_experts: Number of routing experts (graph attention heads).
319
+ dropout: Dropout rate.
320
+ """
321
+
322
+ def __init__(
323
+ self,
324
+ d_model: int,
325
+ d_pair: int = 0,
326
+ n_experts: int = 4,
327
+ dropout: float = 0.0,
328
+ ) -> None:
329
+ super().__init__()
330
+ self.d_model = d_model
331
+ self.d_pair = d_pair if d_pair > 0 else d_model
332
+ self.n_experts = n_experts
333
+
334
+ # ── Graph (router) side β€” projects graph representations ──
335
+ self.graph_router = nn.Linear(d_model, n_experts, bias=False)
336
+ self.graph_adjust_proj = nn.Linear(d_model, self.d_pair, bias=False)
337
+
338
+ # ── Narrative (expert) side β€” projects narrative representations ──
339
+ self.narrative_adjust_proj = nn.Linear(d_model, self.d_pair, bias=False)
340
+
341
+ # ── Co-evolution gate: learns how much each side influences ──
342
+ # the negotiation state
343
+ self.coevolve_gate = nn.Sequential(
344
+ nn.Linear(self.d_pair * 2, self.d_pair, bias=False),
345
+ nn.SiLU(),
346
+ nn.Linear(self.d_pair, self.d_pair, bias=False),
347
+ nn.Sigmoid(),
348
+ )
349
+
350
+ # ── Output projections back to d_model ──
351
+ self.graph_out_proj = nn.Linear(self.d_pair, d_model, bias=False)
352
+ self.narrative_out_proj = nn.Linear(self.d_pair, d_model, bias=False)
353
+
354
+ # ── Normalization ──
355
+ self.norm_graph = nn.RMSNorm(d_model)
356
+ self.norm_narrative = nn.RMSNorm(d_model)
357
+
358
+ self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None
359
+
360
+ # ── Buffers (detached from computation graph) ──
361
+ # Co-evolve state: the shared negotiation state between
362
+ # graph understanding and narrative output.
363
+ self.register_buffer("coevolve_state", torch.zeros(1, 1, self.d_pair))
364
+
365
+ # Routing adjustment: influences which graph nodes (evidence)
366
+ # receive more attention β€” the graph-side "opinion".
367
+ self.register_buffer("routing_adjustment", torch.zeros(1, self.n_experts))
368
+
369
+ def get_routing_adjustment(self) -> torch.Tensor:
370
+ """Return routing adjustment based on current co-evolve state.
371
+
372
+ The adjustment influences which graph nodes (evidence) receive
373
+ more attention β€” it is the graph-side "opinion" derived from
374
+ the current negotiation state between graph understanding and
375
+ narrative output.
376
+
377
+ Returns:
378
+ Tensor of shape ``(1, n_experts)`` with routing adjustments.
379
+ """
380
+ # Compute fresh adjustment from the current co-evolve state
381
+ state_flat = self.coevolve_state.squeeze(1) # (1, d_pair)
382
+ adj = self.graph_router(state_flat) # (1, n_experts)
383
+ return adj + self.routing_adjustment
384
+
385
+ def update_state(
386
+ self,
387
+ graph_repr: torch.Tensor,
388
+ narrative_repr: torch.Tensor,
389
+ ) -> torch.Tensor:
390
+ """Update co-evolve state; return differentiable tensor for gradient flow.
391
+
392
+ Losion v1.9.0 pattern: the returned tensor is differentiable,
393
+ so gradient flows back through the revision path to all
394
+ RouterExpertCoevolve parameters. However, the buffer is
395
+ updated with detached values to prevent unbounded gradient
396
+ accumulation across training steps.
397
+
398
+ This captures the "negotiation" between:
399
+ - Graph understanding: which evidence nodes are most relevant
400
+ - Narrative output: how the evidence is being expressed
401
+
402
+ Each side adjusts the co-evolve state based on its current
403
+ representation, and the gate learns the optimal balance.
404
+
405
+ Args:
406
+ graph_repr: Graph node representations ``(B, S_g, d_model)``.
407
+ Evidence from RSVS graph.
408
+ narrative_repr: Narrative representations ``(B, S_n, d_model)``.
409
+ Sentence arrangement output.
410
+
411
+ Returns:
412
+ Differentiable co-evolve state of shape ``(B, 1, d_pair)``.
413
+ """
414
+ # Project both sides into the co-evolution space
415
+ g_adj = self.graph_adjust_proj(graph_repr) # (B, S_g, d_pair)
416
+ n_adj = self.narrative_adjust_proj(narrative_repr) # (B, S_n, d_pair)
417
+
418
+ # Aggregate across sequence dimension (mean pooling)
419
+ g_pool = g_adj.mean(dim=1, keepdim=True) # (B, 1, d_pair)
420
+ n_pool = n_adj.mean(dim=1, keepdim=True) # (B, 1, d_pair)
421
+
422
+ # Co-evolution gate: learns the negotiation balance between
423
+ # graph understanding and narrative output
424
+ combined = torch.cat([g_pool, n_pool], dim=-1) # (B, 1, d_pair*2)
425
+ gate = self.coevolve_gate(combined) # (B, 1, d_pair)
426
+
427
+ # New state = gated negotiation between graph and narrative,
428
+ # blended with the previous state for stability
429
+ new_state = gate * (g_pool + n_pool) + (1.0 - gate) * self.coevolve_state
430
+
431
+ # IMPORTANT (Losion v1.9.0): Return differentiable version so
432
+ # gradient flows through new_state back to all
433
+ # RouterExpertCoevolve parameters.
434
+ differentiable_state = new_state
435
+
436
+ # Update buffer detached β€” prevents cross-step gradient
437
+ # accumulation while keeping the state current for the next
438
+ # forward pass.
439
+ with torch.no_grad():
440
+ self.coevolve_state.copy_(new_state.detach())
441
+ # Also update routing adjustment based on new state
442
+ adj = self.graph_router(new_state.squeeze(1)) # (B, n_experts)
443
+ self.routing_adjustment.copy_(adj.detach().mean(dim=0, keepdim=True))
444
+
445
+ return differentiable_state
446
+
447
+ def forward(
448
+ self,
449
+ graph_repr: torch.Tensor,
450
+ narrative_repr: torch.Tensor,
451
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
452
+ """Co-evolve graph and narrative representations.
453
+
454
+ This is the main entry point. It updates the co-evolve state
455
+ (capturing the negotiation between graph understanding and
456
+ narrative output) and applies the resulting adjustments to
457
+ both representations.
458
+
459
+ The co-evolution works as follows:
460
+ 1. Graph and narrative representations are projected into a
461
+ shared co-evolution space.
462
+ 2. A gated negotiation combines both perspectives.
463
+ 3. The resulting state adjusts both graph understanding
464
+ (which evidence to attend to) and narrative output
465
+ (how to express the evidence).
466
+
467
+ Args:
468
+ graph_repr: Graph node representations ``(B, S_g, d_model)``.
469
+ Evidence from RSVS graph.
470
+ narrative_repr: Narrative representations ``(B, S_n, d_model)``.
471
+ Sentence arrangement output.
472
+
473
+ Returns:
474
+ Tuple of ``(updated_graph, updated_narrative)`` β€” both
475
+ revised through the co-evolution negotiation.
476
+ """
477
+ # Step 1: Update co-evolve state, get differentiable state
478
+ # (gradient flows through this to all RouterExpertCoevolve params)
479
+ coevolve = self.update_state(graph_repr, narrative_repr) # (B, 1, d_pair)
480
+
481
+ # Step 2: Expand to match input sequence lengths
482
+ coevolve_graph = coevolve.expand(-1, graph_repr.shape[1], -1) # (B, S_g, d_pair)
483
+ coevolve_narrative = coevolve.expand(-1, narrative_repr.shape[1], -1) # (B, S_n, d_pair)
484
+
485
+ # Step 3: Project back to d_model
486
+ graph_adj = self.graph_out_proj(coevolve_graph) # (B, S_g, d_model)
487
+ narrative_adj = self.narrative_out_proj(coevolve_narrative) # (B, S_n, d_model)
488
+
489
+ # Step 4: Apply dropout
490
+ if self.dropout_mod is not None:
491
+ graph_adj = self.dropout_mod(graph_adj)
492
+ narrative_adj = self.dropout_mod(narrative_adj)
493
+
494
+ # Step 5: Residual connection + normalization
495
+ updated_graph = self.norm_graph(graph_repr + graph_adj)
496
+ updated_narrative = self.norm_narrative(narrative_repr + narrative_adj)
497
+
498
+ return updated_graph, updated_narrative
499
+
500
+
501
+ class EvoformerManager(nn.Module):
502
+ """Manages Evoformer feedback levels for AAM Diffusion LLM.
503
+
504
+ Levels:
505
+ 1. LayerRecyclingBlock β€” inter-layer bidirectional feedback
506
+ 2. BidirectionalTokenUpdate β€” token-level bidirectional update
507
+ 3. DecoderPredictFeedback β€” decoder ↔ graph prediction feedback
508
+ 4. PredictionContextRecycling β€” prediction β†’ context recycling
509
+ 5. RouterExpertCoevolve β€” graph node ↔ sentence arrangement co-evolution
510
+ """
511
+
512
+ def __init__(self, config: EvoformerConfig) -> None:
513
+ super().__init__()
514
+ self.config = config
515
+
516
+ if config.use_layer_recycling:
517
+ self.layer_recycling = LayerRecyclingBlock(
518
+ d_model=config.d_model,
519
+ n_recycling_steps=config.n_recycling_steps,
520
+ dropout=config.dropout,
521
+ )
522
+ else:
523
+ self.layer_recycling = None
524
+
525
+ if config.use_token_recycling:
526
+ self.bidirectional_token = BidirectionalTokenUpdate(
527
+ d_model=config.d_model,
528
+ n_heads=max(1, config.d_model // 128),
529
+ dropout=config.dropout,
530
+ )
531
+ else:
532
+ self.bidirectional_token = None
533
+
534
+ if config.use_decoder_feedback:
535
+ self.decoder_feedback = DecoderPredictFeedback(
536
+ d_model=config.d_model,
537
+ n_iterations=config.n_recycling_steps,
538
+ dropout=config.dropout,
539
+ )
540
+ else:
541
+ self.decoder_feedback = None
542
+
543
+ if config.use_prediction_recycling:
544
+ self.prediction_recycling = PredictionContextRecycling(
545
+ d_model=config.d_model,
546
+ dropout=config.dropout,
547
+ )
548
+ else:
549
+ self.prediction_recycling = None
550
+
551
+ if config.use_router_coevolve:
552
+ self.router_coevolve = RouterExpertCoevolve(
553
+ d_model=config.d_model,
554
+ d_pair=config.d_pair,
555
+ n_experts=max(1, config.d_model // 192),
556
+ dropout=config.dropout,
557
+ )
558
+ else:
559
+ self.router_coevolve = None
560
+
561
+ # ================================================================
562
+ # Level 1 β€” Layer Recycling
563
+ # ================================================================
564
+
565
+ def recycle_layers(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
566
+ """Apply Level 1: inter-layer recycling."""
567
+ if self.layer_recycling is not None:
568
+ return self.layer_recycling(hidden_states)
569
+ return hidden_states
570
+
571
+ # ================================================================
572
+ # Level 2 β€” Bidirectional Token Update
573
+ # ================================================================
574
+
575
+ def bidirectional_token_update(self, x: torch.Tensor) -> torch.Tensor:
576
+ """Apply Level 2: bidirectional token update."""
577
+ if self.bidirectional_token is not None:
578
+ return self.bidirectional_token(x)
579
+ return x
580
+
581
+ # ================================================================
582
+ # Level 3 β€” Decoder ↔ Predict Feedback
583
+ # ================================================================
584
+
585
+ def apply_decoder_feedback(self, hidden_state: torch.Tensor, decoder_output: torch.Tensor) -> torch.Tensor:
586
+ """Apply Level 3: decoder-predict feedback.
587
+
588
+ AAM-specific: narrative output revises graph conditioning.
589
+ """
590
+ if self.decoder_feedback is not None:
591
+ return self.decoder_feedback(hidden_state, decoder_output)
592
+ return hidden_state
593
+
594
+ def decoder_predict_feedback(self, hidden_state: torch.Tensor, decoder_output: torch.Tensor) -> torch.Tensor:
595
+ """Convenience method for Level 3 (self-referential alias).
596
+
597
+ Same as :meth:`apply_decoder_feedback` β€” provided for
598
+ discoverability and symmetry with the module name.
599
+ """
600
+ return self.apply_decoder_feedback(hidden_state, decoder_output)
601
+
602
+ # ================================================================
603
+ # Level 4 β€” Prediction β†’ Context Recycling
604
+ # ================================================================
605
+
606
+ def apply_prediction_recycling(self, hidden_states: torch.Tensor, prediction_logits: torch.Tensor) -> torch.Tensor:
607
+ """Apply Level 4: prediction-context recycling.
608
+
609
+ AAM-specific: predicted narrative refines graph understanding.
610
+ """
611
+ if self.prediction_recycling is not None:
612
+ return self.prediction_recycling(hidden_states, prediction_logits)
613
+ return hidden_states
614
+
615
+ def prediction_context_recycling(self, hidden_states: torch.Tensor, prediction_logits: torch.Tensor) -> torch.Tensor:
616
+ """Convenience method for Level 4 (self-referential alias).
617
+
618
+ Same as :meth:`apply_prediction_recycling` β€” provided for
619
+ discoverability and symmetry with the module name.
620
+ """
621
+ return self.apply_prediction_recycling(hidden_states, prediction_logits)
622
+
623
+ # ================================================================
624
+ # Level 5 β€” Router-Expert Co-evolution
625
+ # ================================================================
626
+
627
+ def apply_router_coevolve(
628
+ self,
629
+ graph_repr: torch.Tensor,
630
+ narrative_repr: torch.Tensor,
631
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
632
+ """Apply Level 5: graph node ↔ sentence arrangement co-evolution.
633
+
634
+ AAM-specific: graph understanding and narrative output negotiate
635
+ through the co-evolve state, each adjusting based on the other.
636
+
637
+ Args:
638
+ graph_repr: Graph node representations ``(B, S_g, d_model)``.
639
+ narrative_repr: Narrative representations ``(B, S_n, d_model)``.
640
+
641
+ Returns:
642
+ Tuple of ``(updated_graph, updated_narrative)``.
643
+ """
644
+ if self.router_coevolve is not None:
645
+ return self.router_coevolve(graph_repr, narrative_repr)
646
+ return graph_repr, narrative_repr
647
+
648
+ def router_expert_coevolve(
649
+ self,
650
+ graph_repr: torch.Tensor,
651
+ narrative_repr: torch.Tensor,
652
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
653
+ """Convenience method for Level 5 (self-referential alias).
654
+
655
+ Same as :meth:`apply_router_coevolve` β€” named after the
656
+ Losion module for discoverability.
657
+
658
+ Args:
659
+ graph_repr: Graph node representations ``(B, S_g, d_model)``.
660
+ narrative_repr: Narrative representations ``(B, S_n, d_model)``.
661
+
662
+ Returns:
663
+ Tuple of ``(updated_graph, updated_narrative)``.
664
+ """
665
+ return self.apply_router_coevolve(graph_repr, narrative_repr)
666
+
667
+ # ================================================================
668
+ # Reset
669
+ # ================================================================
670
+
671
+ def reset(self) -> None:
672
+ """Reset all mutable state (buffers, counters).
673
+
674
+ Call this at the start of a new sequence or inference run to
675
+ clear the co-evolve state and routing adjustments from
676
+ previous inputs.
677
+ """
678
+ if self.router_coevolve is not None:
679
+ self.router_coevolve.coevolve_state.zero_()
680
+ self.router_coevolve.routing_adjustment.zero_()
681
+
682
+ # ================================================================
683
+ # Statistics
684
+ # ================================================================
685
+
686
+ def get_stats(self) -> Dict[str, object]:
687
+ """Return activation status for all Evoformer levels."""
688
+ return {
689
+ "level_1_layer_recycling": self.layer_recycling is not None,
690
+ "level_2_bidirectional_token": self.bidirectional_token is not None,
691
+ "level_3_decoder_feedback": self.decoder_feedback is not None,
692
+ "level_4_prediction_recycling": self.prediction_recycling is not None,
693
+ "level_5_router_coevolve": self.router_coevolve is not None,
694
+ "n_recycling_steps": self.config.n_recycling_steps,
695
+ "d_pair": self.config.d_pair if self.config.d_pair > 0 else self.config.d_model,
696
+ }