"""AAM Diffusion LLM: Evoformer Feedback System

Adapted from Losion/AlphaFold2: iterative bidirectional feedback
at multiple architecture levels.

For AAM, the most relevant levels:
    Level 1 - Inter-Layer Recycling: Layer deep ↔ Layer shallow
    Level 2 - Bidirectional Token Update: Token old ↔ Token new
    Level 3 - Decoder ↔ Predict: Narrative output ↔ Graph conditioning
    Level 4 - Prediction → Context: Predicted narrative refines graph understanding
    Level 5 - Router-Expert Co-evolution: Graph node ↔ Sentence arrangement

Core Principle: "Whenever there are two related representations, replace
one-way information flow with iterative bidirectional dialogue."

This maps directly onto AAM's Predictive Coding loop:
    predict(X) → observe(Y) → belief_update(Δ)

Evoformer makes this bidirectional and iterative.

Level 5 (RouterExpertCoevolve), AAM-specific adaptation:
    In Losion, this handles router ↔ MoE expert co-evolution.
    For AAM, this handles: graph node ↔ sentence arrangement co-evolution.
    The co-evolve state captures the "negotiation" between graph
    understanding and narrative output: each side adjusts based on
    the other's current state, creating an iterative dialogue where
    better graph understanding leads to better narrative, and better
    narrative feedback refines graph understanding.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class EvoformerConfig:
    """Configuration for Evoformer Feedback System.

    Attributes:
        d_model: Model hidden dimension.
        n_recycling_steps: Number of recycling iterations.
        dropout: Dropout rate for all sub-modules.
        use_layer_recycling: Enable Level 1 (inter-layer recycling).
        use_token_recycling: Enable Level 2 (bidirectional token update).
        use_decoder_feedback: Enable Level 3 (decoder-predict feedback).
        use_prediction_recycling: Enable Level 4 (prediction-context recycling).
        use_router_coevolve: Enable Level 5 (router-expert co-evolution).
        d_pair: Pair representation dimension for co-evolution state.
            0 means use d_model.
        min_recycling_improvement: Minimum improvement threshold for recycling.
    """

    d_model: int = 768
    n_recycling_steps: int = 3
    dropout: float = 0.0
    use_layer_recycling: bool = True
    use_token_recycling: bool = True
    use_decoder_feedback: bool = True
    use_prediction_recycling: bool = True
    use_router_coevolve: bool = True
    d_pair: int = 0  # 0 = use d_model
    min_recycling_improvement: float = 1e-4


class LayerRecyclingBlock(nn.Module):
    """Level 1: Bidirectional feedback between deep and shallow layers.

    Losion v1.9.0 gradient-flow fix: deep layers also receive a small
    revision residual (0.05 multiplier) so that ``recycled[-1]`` carries
    gradient through the revision path back to all layer_recycling
    parameters.  Without this, deep layers get no revision and the
    gradient from the final output cannot flow back through the
    revision path.
    """

    def __init__(self, d_model: int, n_recycling_steps: int = 2, dropout: float = 0.0) -> None:
        super().__init__()
        self.d_model = d_model
        self.n_recycling_steps = n_recycling_steps

        self.shallow_query_proj = nn.Linear(d_model, d_model, bias=False)
        self.deep_key_proj = nn.Linear(d_model, d_model, bias=False)
        self.deep_value_proj = nn.Linear(d_model, d_model, bias=False)
        self.revision_proj = nn.Linear(d_model, d_model, bias=False)

        self.revision_gate = nn.Sequential(
            nn.Linear(d_model * 2, 1, bias=False),
            nn.Sigmoid(),
        )

        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.scale = math.sqrt(d_model)

        # Losion v1.9.0: deep-layer revision multiplier (small but nonzero
        # to maintain gradient flow through the revision path).
        self.deep_revision_multiplier: float = 0.05

    def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        if len(hidden_states) < 2:
            return hidden_states

        n_layers = len(hidden_states)
        mid = n_layers // 2
        shallow_repr = torch.stack(hidden_states[:mid], dim=0).mean(dim=0)
        deep_repr = torch.stack(hidden_states[mid:], dim=0).mean(dim=0)

        q = self.shallow_query_proj(shallow_repr)
        k = self.deep_key_proj(deep_repr)
        v = self.deep_value_proj(deep_repr)

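        # Attend from each shallow position to every deep position.  Pooling
        # the keys down to a single vector would make the softmax below run
        # over one element (always 1.0), which cuts the gradient to the query
        # and key projections.
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # (B, S, S)
        attn = F.softmax(scores, dim=-1)

        if self.dropout is not None:
            attn = self.dropout(attn)

        revision = torch.matmul(attn, v)  # (B, S, d_model)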
        revision = self.revision_proj(revision)

        gate = self.revision_gate(torch.cat([shallow_repr, revision], dim=-1))
        revision = gate * revision

        revised = []
        for i, h in enumerate(hidden_states):
            if i < mid:
                revised.append(h + revision * (0.1 if i < mid // 2 else 0.2))
            else:
                # Losion v1.9.0 fix: deep layers receive a small revision
                # residual so gradient flows from recycled[-1] back through
                # the revision path to all layer_recycling parameters.
                revised.append(h + revision * self.deep_revision_multiplier)

        return revised
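

# Illustrative sketch (not part of the Losion or AAM API): the per-layer
# revision weights that LayerRecyclingBlock.forward applies, written out as a
# plain function so the schedule is easy to inspect.  The helper name
# ``_revision_multipliers_sketch`` is hypothetical; the constants mirror the
# ones used above (0.1 and 0.2 for the shallow half, 0.05 for the deep half).
def _revision_multipliers_sketch(n_layers: int, deep_multiplier: float = 0.05) -> list:
    """Return the revision weight applied to each layer index."""
    mid = n_layers // 2
    weights = []
    for i in range(n_layers):
        if i < mid:
            # Shallow half: the shallowest quarter gets the weaker revision.
            weights.append(0.1 if i < mid // 2 else 0.2)
        else:
            # Deep half: small residual that keeps the gradient path alive.
            weights.append(deep_multiplier)
    return weights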


class BidirectionalTokenUpdate(nn.Module):
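    """Level 2: Later tokens revise earlier token representations.

    The self-attention here is unmasked, so every position attends to the
    full sequence, including positions that come after it.
    """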

    def __init__(self, d_model: int, n_heads: int = 8, dropout: float = 0.0) -> None:
        super().__init__()
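        if d_model % n_heads != 0:
            # forward() reshapes to (n_heads, d_model // n_heads), which
            # only works when the split is exact.
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_kv = d_model // n_heads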

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        self.gate = nn.Sequential(
            nn.Linear(d_model, 1, bias=False),
            nn.Sigmoid(),
        )

        self.norm = nn.RMSNorm(d_model)
        self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None
        self.scale = math.sqrt(self.d_kv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        if seq_len <= 1:
            return x

        q = self.q_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.n_heads, self.d_kv).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(x.dtype)

        if self.dropout_mod is not None:
            attn = self.dropout_mod(attn)

        backward_info = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        backward_info = self.out_proj(backward_info)

        gate = self.gate(x)
        revised = x + gate * backward_info
        revised = self.norm(revised)

        return revised


class DecoderPredictFeedback(nn.Module):
    """Level 3: Bidirectional feedback between decoder output and graph prediction.

    AAM-specific: narrative output revises graph conditioning.
    Predict v1 → Decoder refine → feedback → Update v1 → loop
    """

    def __init__(self, d_model: int, n_iterations: int = 2, dropout: float = 0.0) -> None:
        super().__init__()
        self.d_model = d_model
        self.n_iterations = n_iterations

        self.feedback_proj = nn.Sequential(
            nn.Linear(d_model, d_model, bias=False),
            nn.SiLU(),
            nn.Linear(d_model, d_model, bias=False),
        )

        self.feedback_gate = nn.Sequential(
            nn.Linear(d_model, 1, bias=False),
            nn.Sigmoid(),
        )

        self.norm = nn.RMSNorm(d_model)
        self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, hidden_state: torch.Tensor, decoder_output: torch.Tensor) -> torch.Tensor:
        delta = decoder_output - hidden_state
        feedback = self.feedback_proj(delta)
        gate = self.feedback_gate(hidden_state)
        feedback = gate * feedback

        if self.dropout_mod is not None:
            feedback = self.dropout_mod(feedback)

        updated = self.norm(hidden_state + feedback)
        return updated


class PredictionContextRecycling(nn.Module):
    """Level 4: Predicted narrative revises graph understanding.

    AAM-specific: the generated narrative can refine how we understand
    the graph, creating a feedback loop between output and input.
    """

    def __init__(self, d_model: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.d_model = d_model

        self.pred_proj = nn.Linear(d_model, d_model, bias=False)
        self.context_query = nn.Linear(d_model, d_model, bias=False)
        self.pred_key = nn.Linear(d_model, d_model, bias=False)
        self.pred_value = nn.Linear(d_model, d_model, bias=False)
        self.revision_proj = nn.Linear(d_model, d_model, bias=False)
        self.revision_gate = nn.Sequential(
            nn.Linear(d_model, 1, bias=False),
            nn.Sigmoid(),
        )

        self.norm = nn.RMSNorm(d_model)
        self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None
        self.scale = math.sqrt(d_model)

    def forward(self, hidden_states: torch.Tensor, prediction_logits: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = hidden_states.shape

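        # Use the last position as the prediction summary: (B, 1, width).
        if prediction_logits.dim() == 3:
            pred = prediction_logits[:, -1:, :]
        else:
            pred = prediction_logits.unsqueeze(1)

        if pred.shape[-1] != self.d_model:
            # Width mismatch (e.g. vocabulary-sized logits): keep the first
            # d_model features and project them.  This assumes the last
            # dimension is at least d_model wide.
            pred_repr = self.pred_proj(pred[..., :self.d_model])
        else:
            pred_repr = pred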

        q = self.context_query(hidden_states)
        k = self.pred_key(pred_repr)
        v = self.pred_value(pred_repr)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # (B, S, 1)
        # Single prediction key: normalize over the S context positions.
        attn = F.softmax(scores, dim=-2)

        if self.dropout_mod is not None:
            attn = self.dropout_mod(attn)

        revision = torch.matmul(attn, v)
        revision = self.revision_proj(revision)

        gate = self.revision_gate(hidden_states)
        revised = hidden_states + gate * revision
        revised = self.norm(revised)

        return revised
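

# Illustrative sketch (hypothetical helper, not part of the Losion or AAM
# API): with a single prediction key the score tensor has shape (B, S, 1), and
# a softmax over dim=-2, as used in PredictionContextRecycling.forward,
# normalizes across the S context positions rather than across keys.  The
# imports are repeated so the sketch can be run on its own.
import torch
import torch.nn.functional as F


def _single_key_softmax_sketch(batch: int = 2, seq_len: int = 5):
    """Compare softmax over positions (dim=-2) with softmax over keys (dim=-1)."""
    scores = torch.randn(batch, seq_len, 1)
    over_positions = F.softmax(scores, dim=-2)  # sums to 1 over the S positions
    over_keys = F.softmax(scores, dim=-1)       # one key only, so every weight is 1
    return over_positions, over_keys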


class RouterExpertCoevolve(nn.Module):
    """Level 5: Graph node ↔ sentence arrangement co-evolution.

    Adapted from Losion's RouterExpertCoevolve (router ↔ MoE expert
    co-evolution).  In Losion, the router distributes tokens to MoE
    experts, and expert outputs refine the router's decisions: a
    bidirectional negotiation.

    For AAM, the co-evolution is between:
        - Graph nodes: evidence from RSVS graph (the "router" side,
          which evidence to attend to)
        - Sentence arrangement: narrative output (the "expert" side,
          how to express the evidence in natural language)

    The co-evolve state captures the "negotiation" between graph
    understanding and narrative output: each side adjusts based on
    the other's current state, creating an iterative dialogue where
    better graph understanding leads to better narrative, and better
    narrative feedback refines graph understanding.

    Key design (from Losion v1.9.0):
        - ``update_state()`` returns a **differentiable** tensor so
          gradient flows through the revision path to all
          RouterExpertCoevolve parameters.
        - The internal buffer is updated with **detached** values to
          prevent unbounded gradient accumulation across training steps.

    Args:
        d_model: Model hidden dimension.
        d_pair: Pair (co-evolution state) dimension.  0 means use d_model.
        n_experts: Number of routing experts (graph attention heads).
        dropout: Dropout rate.
    """

    def __init__(
        self,
        d_model: int,
        d_pair: int = 0,
        n_experts: int = 4,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_pair = d_pair if d_pair > 0 else d_model
        self.n_experts = n_experts

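        # ── Graph (router) side: projects graph representations ──
        # The router is applied to the co-evolve state, which lives in
        # d_pair space, so its input width must be d_pair.
        self.graph_router = nn.Linear(self.d_pair, n_experts, bias=False)
        self.graph_adjust_proj = nn.Linear(d_model, self.d_pair, bias=False)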

        # ── Narrative (expert) side: projects narrative representations ──
        self.narrative_adjust_proj = nn.Linear(d_model, self.d_pair, bias=False)

        # ── Co-evolution gate: learns how much each side influences ──
        # the negotiation state
        self.coevolve_gate = nn.Sequential(
            nn.Linear(self.d_pair * 2, self.d_pair, bias=False),
            nn.SiLU(),
            nn.Linear(self.d_pair, self.d_pair, bias=False),
            nn.Sigmoid(),
        )

        # ── Output projections back to d_model ──
        self.graph_out_proj = nn.Linear(self.d_pair, d_model, bias=False)
        self.narrative_out_proj = nn.Linear(self.d_pair, d_model, bias=False)

        # ── Normalization ──
        self.norm_graph = nn.RMSNorm(d_model)
        self.norm_narrative = nn.RMSNorm(d_model)

        self.dropout_mod = nn.Dropout(dropout) if dropout > 0 else None

        # ── Buffers (detached from computation graph) ──
        # Co-evolve state: the shared negotiation state between
        # graph understanding and narrative output.
        self.register_buffer("coevolve_state", torch.zeros(1, 1, self.d_pair))

        # Routing adjustment: influences which graph nodes (evidence)
        # receive more attention. This is the graph-side "opinion".
        self.register_buffer("routing_adjustment", torch.zeros(1, self.n_experts))

    def get_routing_adjustment(self) -> torch.Tensor:
        """Return routing adjustment based on current co-evolve state.

        The adjustment influences which graph nodes (evidence) receive
        more attention. It is the graph-side "opinion" derived from
        the current negotiation state between graph understanding and
        narrative output.

        Returns:
            Tensor of shape ``(1, n_experts)`` with routing adjustments.
        """
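        # Compute fresh adjustment from the current co-evolve state.
        # Clone the buffer first: autograd saves this input for the
        # router's weight gradient, and update_state() later overwrites
        # the buffer in place.
        state_flat = self.coevolve_state.squeeze(1).clone()  # (1, d_pair)
        adj = self.graph_router(state_flat)  # (1, n_experts)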
        return adj + self.routing_adjustment

    def update_state(
        self,
        graph_repr: torch.Tensor,
        narrative_repr: torch.Tensor,
    ) -> torch.Tensor:
        """Update co-evolve state; return differentiable tensor for gradient flow.

        Losion v1.9.0 pattern: the returned tensor is differentiable,
        so gradient flows back through the revision path to all
        RouterExpertCoevolve parameters.  However, the buffer is
        updated with detached values to prevent unbounded gradient
        accumulation across training steps.

        This captures the "negotiation" between:
        - Graph understanding: which evidence nodes are most relevant
        - Narrative output: how the evidence is being expressed

        Each side adjusts the co-evolve state based on its current
        representation, and the gate learns the optimal balance.

        Args:
            graph_repr: Graph node representations ``(B, S_g, d_model)``.
                Evidence from RSVS graph.
            narrative_repr: Narrative representations ``(B, S_n, d_model)``.
                Sentence arrangement output.

        Returns:
            Differentiable co-evolve state of shape ``(B, 1, d_pair)``.
        """
        # Project both sides into the co-evolution space
        g_adj = self.graph_adjust_proj(graph_repr)       # (B, S_g, d_pair)
        n_adj = self.narrative_adjust_proj(narrative_repr)  # (B, S_n, d_pair)

        # Aggregate across sequence dimension (mean pooling)
        g_pool = g_adj.mean(dim=1, keepdim=True)   # (B, 1, d_pair)
        n_pool = n_adj.mean(dim=1, keepdim=True)   # (B, 1, d_pair)

        # Co-evolution gate: learns the negotiation balance between
        # graph understanding and narrative output
        combined = torch.cat([g_pool, n_pool], dim=-1)  # (B, 1, d_pair*2)
        gate = self.coevolve_gate(combined)               # (B, 1, d_pair)

        # New state = gated negotiation between graph and narrative,
        # blended with the previous state for stability.  The buffer is
        # cloned because autograd saves it for the backward pass and it
        # is overwritten in place below.
        prev_state = self.coevolve_state.clone()
        new_state = gate * (g_pool + n_pool) + (1.0 - gate) * prev_state

        # IMPORTANT (Losion v1.9.0): Return differentiable version so
        # gradient flows through new_state back to all
        # RouterExpertCoevolve parameters.
        differentiable_state = new_state

        # Update buffer detached: prevents cross-step gradient
        # accumulation while keeping the state current for the next
        # forward pass.  The buffer holds one shared state of shape
        # (1, 1, d_pair), so per-sample states are averaged over the batch.
        with torch.no_grad():
            self.coevolve_state.copy_(new_state.detach().mean(dim=0, keepdim=True))
            # Also update routing adjustment based on new state
            adj = self.graph_router(new_state.squeeze(1))  # (B, n_experts)
            self.routing_adjustment.copy_(adj.detach().mean(dim=0, keepdim=True))

        return differentiable_state

    def forward(
        self,
        graph_repr: torch.Tensor,
        narrative_repr: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Co-evolve graph and narrative representations.

        This is the main entry point.  It updates the co-evolve state
        (capturing the negotiation between graph understanding and
        narrative output) and applies the resulting adjustments to
        both representations.

        The co-evolution works as follows:
        1. Graph and narrative representations are projected into a
           shared co-evolution space.
        2. A gated negotiation combines both perspectives.
        3. The resulting state adjusts both graph understanding
           (which evidence to attend to) and narrative output
           (how to express the evidence).

        Args:
            graph_repr: Graph node representations ``(B, S_g, d_model)``.
                Evidence from RSVS graph.
            narrative_repr: Narrative representations ``(B, S_n, d_model)``.
                Sentence arrangement output.

        Returns:
            Tuple of ``(updated_graph, updated_narrative)``, both
            revised through the co-evolution negotiation.
        """
        # Step 1: Update co-evolve state, get differentiable state
        # (gradient flows through this to all RouterExpertCoevolve params)
        coevolve = self.update_state(graph_repr, narrative_repr)  # (B, 1, d_pair)

        # Step 2: Expand to match input sequence lengths
        coevolve_graph = coevolve.expand(-1, graph_repr.shape[1], -1)          # (B, S_g, d_pair)
        coevolve_narrative = coevolve.expand(-1, narrative_repr.shape[1], -1)  # (B, S_n, d_pair)

        # Step 3: Project back to d_model
        graph_adj = self.graph_out_proj(coevolve_graph)            # (B, S_g, d_model)
        narrative_adj = self.narrative_out_proj(coevolve_narrative)  # (B, S_n, d_model)

        # Step 4: Apply dropout
        if self.dropout_mod is not None:
            graph_adj = self.dropout_mod(graph_adj)
            narrative_adj = self.dropout_mod(narrative_adj)

        # Step 5: Residual connection + normalization
        updated_graph = self.norm_graph(graph_repr + graph_adj)
        updated_narrative = self.norm_narrative(narrative_repr + narrative_adj)

        return updated_graph, updated_narrative
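

# Illustrative sketch (hypothetical class, not part of the Losion or AAM API)
# of the state-handling pattern described in ``update_state``: hand the caller
# a differentiable tensor, and store only a detached copy in the buffer.  The
# blend weight ``mix`` and the class name are assumptions made for this
# example.  The imports are repeated so the sketch can be run on its own.
import torch
import torch.nn as nn


class _DetachedStateSketch(nn.Module):
    def __init__(self, dim: int = 4) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        self.mix = nn.Parameter(torch.tensor(0.5))
        self.register_buffer("state", torch.zeros(1, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clone the buffer: autograd saves it for backward, and it is
        # overwritten in place a few lines below.
        prev = self.state.clone()
        pooled = self.proj(x).mean(dim=0, keepdim=True)  # (1, dim)
        new_state = self.mix * pooled + (1.0 - self.mix) * prev
        with torch.no_grad():
            # The stored copy carries no graph, so nothing accumulates
            # across training steps.
            self.state.copy_(new_state.detach())
        return new_state  # still attached to this step's graph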


class EvoformerManager(nn.Module):
    """Manages Evoformer feedback levels for AAM Diffusion LLM.

    Levels:
        1. LayerRecyclingBlock: inter-layer bidirectional feedback
        2. BidirectionalTokenUpdate: token-level bidirectional update
        3. DecoderPredictFeedback: decoder ↔ graph prediction feedback
        4. PredictionContextRecycling: prediction → context recycling
        5. RouterExpertCoevolve: graph node ↔ sentence arrangement co-evolution
    """

    def __init__(self, config: EvoformerConfig) -> None:
        super().__init__()
        self.config = config

        if config.use_layer_recycling:
            self.layer_recycling = LayerRecyclingBlock(
                d_model=config.d_model,
                n_recycling_steps=config.n_recycling_steps,
                dropout=config.dropout,
            )
        else:
            self.layer_recycling = None

        if config.use_token_recycling:
            self.bidirectional_token = BidirectionalTokenUpdate(
                d_model=config.d_model,
                n_heads=max(1, config.d_model // 128),
                dropout=config.dropout,
            )
        else:
            self.bidirectional_token = None

        if config.use_decoder_feedback:
            self.decoder_feedback = DecoderPredictFeedback(
                d_model=config.d_model,
                n_iterations=config.n_recycling_steps,
                dropout=config.dropout,
            )
        else:
            self.decoder_feedback = None

        if config.use_prediction_recycling:
            self.prediction_recycling = PredictionContextRecycling(
                d_model=config.d_model,
                dropout=config.dropout,
            )
        else:
            self.prediction_recycling = None

        if config.use_router_coevolve:
            self.router_coevolve = RouterExpertCoevolve(
                d_model=config.d_model,
                d_pair=config.d_pair,
                n_experts=max(1, config.d_model // 192),
                dropout=config.dropout,
            )
        else:
            self.router_coevolve = None

    # ================================================================
    # Level 1 - Layer Recycling
    # ================================================================

    def recycle_layers(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        """Apply Level 1: inter-layer recycling."""
        if self.layer_recycling is not None:
            return self.layer_recycling(hidden_states)
        return hidden_states

    # ================================================================
    # Level 2 - Bidirectional Token Update
    # ================================================================

    def bidirectional_token_update(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Level 2: bidirectional token update."""
        if self.bidirectional_token is not None:
            return self.bidirectional_token(x)
        return x

    # ================================================================
    # Level 3 - Decoder ↔ Predict Feedback
    # ================================================================

    def apply_decoder_feedback(self, hidden_state: torch.Tensor, decoder_output: torch.Tensor) -> torch.Tensor:
        """Apply Level 3: decoder-predict feedback.

        AAM-specific: narrative output revises graph conditioning.
        """
        if self.decoder_feedback is not None:
            return self.decoder_feedback(hidden_state, decoder_output)
        return hidden_state

    def decoder_predict_feedback(self, hidden_state: torch.Tensor, decoder_output: torch.Tensor) -> torch.Tensor:
        """Alias of :meth:`apply_decoder_feedback` for Level 3.

        Provided for discoverability and symmetry with the module name.
        """
        return self.apply_decoder_feedback(hidden_state, decoder_output)

    # ================================================================
    # Level 4 - Prediction → Context Recycling
    # ================================================================

    def apply_prediction_recycling(self, hidden_states: torch.Tensor, prediction_logits: torch.Tensor) -> torch.Tensor:
        """Apply Level 4: prediction-context recycling.

        AAM-specific: predicted narrative refines graph understanding.
        """
        if self.prediction_recycling is not None:
            return self.prediction_recycling(hidden_states, prediction_logits)
        return hidden_states

    def prediction_context_recycling(self, hidden_states: torch.Tensor, prediction_logits: torch.Tensor) -> torch.Tensor:
        """Alias of :meth:`apply_prediction_recycling` for Level 4.

        Provided for discoverability and symmetry with the module name.
        """
        return self.apply_prediction_recycling(hidden_states, prediction_logits)

    # ================================================================
    # Level 5 - Router-Expert Co-evolution
    # ================================================================

    def apply_router_coevolve(
        self,
        graph_repr: torch.Tensor,
        narrative_repr: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply Level 5: graph node ↔ sentence arrangement co-evolution.

        AAM-specific: graph understanding and narrative output negotiate
        through the co-evolve state, each adjusting based on the other.

        Args:
            graph_repr: Graph node representations ``(B, S_g, d_model)``.
            narrative_repr: Narrative representations ``(B, S_n, d_model)``.

        Returns:
            Tuple of ``(updated_graph, updated_narrative)``.
        """
        if self.router_coevolve is not None:
            return self.router_coevolve(graph_repr, narrative_repr)
        return graph_repr, narrative_repr

    def router_expert_coevolve(
        self,
        graph_repr: torch.Tensor,
        narrative_repr: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Alias of :meth:`apply_router_coevolve` for Level 5.

        Named after the Losion module for discoverability.

        Args:
            graph_repr: Graph node representations ``(B, S_g, d_model)``.
            narrative_repr: Narrative representations ``(B, S_n, d_model)``.

        Returns:
            Tuple of ``(updated_graph, updated_narrative)``.
        """
        return self.apply_router_coevolve(graph_repr, narrative_repr)

    # ================================================================
    # Reset
    # ================================================================

    def reset(self) -> None:
        """Reset all mutable state (buffers, counters).

        Call this at the start of a new sequence or inference run to
        clear the co-evolve state and routing adjustments from
        previous inputs.
        """
        if self.router_coevolve is not None:
            self.router_coevolve.coevolve_state.zero_()
            self.router_coevolve.routing_adjustment.zero_()

    # ================================================================
    # Statistics
    # ================================================================

    def get_stats(self) -> Dict[str, object]:
        """Return activation status for all Evoformer levels."""
        return {
            "level_1_layer_recycling": self.layer_recycling is not None,
            "level_2_bidirectional_token": self.bidirectional_token is not None,
            "level_3_decoder_feedback": self.decoder_feedback is not None,
            "level_4_prediction_recycling": self.prediction_recycling is not None,
            "level_5_router_coevolve": self.router_coevolve is not None,
            "n_recycling_steps": self.config.n_recycling_steps,
            "d_pair": self.config.d_pair if self.config.d_pair > 0 else self.config.d_model,
        }