Wolfvin committed on
Commit
3ddd8b6
·
verified ·
1 Parent(s): 921c3d4

Upload diffusion_llm/model/aam_diffusion_model.py with huggingface_hub

Browse files
diffusion_llm/model/aam_diffusion_model.py CHANGED
@@ -1,16 +1,22 @@
1
  """
2
- AAM Diffusion LLM — Complete Model
3
 
4
  Combines the Diffusion Transformer, Graph Encoder, and Noise Scheduler
5
  into a single, unified model for training and inference.
6
 
7
- This is the "body" of AAM — the specialized sentence composer that
8
- takes graph conditioning as input and produces coherent narratives
9
- through iterative denoising.
 
 
 
 
 
 
10
 
11
  Architecture:
12
  ┌──────────────────────────────────────────────────┐
13
- │ AAM Diffusion Model (The Body)
14
  │ │
15
  │ Input: │
16
  │ - Token IDs (text) │
@@ -23,16 +29,22 @@ Architecture:
23
  │ 3. Add noise: x_t = schedule.add_noise(x_0, t) │
24
  │ 4. Encode graph conditioning │
25
  │ 5. Predict noise: eps = transformer(x_t, t, c) │
26
- │ 6. Compute loss: L = MSE(eps, eps_target)
 
27
  │ │
28
- │ Inference Process:
 
 
 
 
 
 
29
  │ 1. Start from pure noise x_T │
30
  │ 2. Encode graph conditioning │
31
  │ 3. For t = T, T-1, ..., 1: │
32
  │ a. Predict noise: eps = transformer(x_t, t) │
33
  │ b. Denoise: x_{t-1} = schedule.step(eps) │
34
  │ 4. Decode final x_0 → text tokens │
35
- │ 5. Detokenize → natural language narrative │
36
  │ │
37
  │ Key Constraint: │
38
  │ The model CANNOT generate information not │
@@ -45,14 +57,17 @@ Architecture:
45
  └──────────────────────────────────────────────────┘
46
 
47
  Analogi: Ini adalah seluruh "tubuh" Jin Soun — bukan hanya
48
- ototnya (transformer), tapi juga sistem saraf (graph encoder)
49
- dan kemampuan untuk memperbaiki diri (diffusion denoising).
 
 
 
50
  """
51
 
52
  from __future__ import annotations
53
 
54
  import logging
55
- from typing import Optional
56
 
57
  import torch
58
  import torch.nn as nn
@@ -66,12 +81,18 @@ logger = logging.getLogger(__name__)
66
 
67
 
68
  class AamDiffusionModel(nn.Module):
69
- """Complete AAM Diffusion LLM model.
70
 
71
  Combines:
72
  - DiffusionTransformer: Core denoising network
73
  - GraphConditioningEncoder: Encodes graph structure for conditioning
74
  - NoiseScheduler: Manages the diffusion process
 
 
 
 
 
 
75
 
76
  This model is designed to be trained on Graph→Narrative pairs,
77
  where the graph data comes from the RSVS Knowledge Graph and
@@ -85,7 +106,20 @@ class AamDiffusionModel(nn.Module):
85
  super().__init__()
86
  self.config = config
87
 
88
- # Core components
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  self.noise_scheduler = NoiseScheduler(
90
  n_timesteps=config.diffusion.n_timesteps,
91
  schedule_type=config.diffusion.schedule_type,
@@ -103,29 +137,146 @@ class AamDiffusionModel(nn.Module):
103
 
104
  self.transformer = DiffusionTransformer(config.model)
105
 
106
- # Token-to-embedding projection (shared with transformer)
107
- # The transformer's token_embedding is used for both
108
- # encoding input text and decoding output text
109
-
110
- # Output head: project from d_model to vocab_size
111
- self.lm_head = nn.Linear(
112
- config.model.d_model, config.model.vocab_size, bias=False
113
- )
114
-
115
- # Tie weights between token embedding and LM head
116
- # This is standard practice and reduces parameter count
117
- self.lm_head.weight = self.transformer.token_embedding.weight
118
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  # EMA model (for inference, updated during training)
 
120
  self._ema_model: Optional[AamDiffusionModel] = None
121
  self._ema_decay = config.training.ema_decay
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  logger.info(
124
- "AamDiffusionModel initialized: %s params, %s",
125
  self._format_params(self.get_num_params()),
126
  config.model_name,
 
127
  )
128
 
 
 
 
 
129
  def forward(
130
  self,
131
  token_ids: torch.Tensor,
@@ -141,14 +292,15 @@ class AamDiffusionModel(nn.Module):
141
  reasoning_ids: Optional[torch.Tensor] = None,
142
  reasoning_confidence: Optional[torch.Tensor] = None,
143
  source_trust: Optional[torch.Tensor] = None,
144
- ) -> torch.Tensor:
145
  """Forward pass for training.
146
 
147
  1. Get clean embeddings from token IDs
148
  2. Add noise at the given timestep
149
  3. Encode graph conditioning
150
  4. Predict noise via transformer
151
- 5. Return predicted noise (loss computed externally)
 
152
 
153
  Args:
154
  token_ids: Clean text token IDs, shape (batch, seq_len).
@@ -166,7 +318,7 @@ class AamDiffusionModel(nn.Module):
166
  source_trust: Source trust score.
167
 
168
  Returns:
169
- Predicted noise tensor of shape (batch, seq_len, d_model).
170
  """
171
  # Step 1: Get clean embeddings (x_0)
172
  x_0 = self.transformer.token_embedding(token_ids)
@@ -196,6 +348,17 @@ class AamDiffusionModel(nn.Module):
196
  graph_keys = graph_cond.get("keys")
197
  graph_values = graph_cond.get("values")
198
 
 
 
 
 
 
 
 
 
 
 
 
199
  # Step 4: Predict noise via transformer
200
  predicted = self.transformer(
201
  x_t=x_t,
@@ -204,8 +367,37 @@ class AamDiffusionModel(nn.Module):
204
  graph_values=graph_values,
205
  )
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  return predicted, noise
208
 
 
 
 
 
209
  def compute_loss(
210
  self,
211
  predicted: torch.Tensor,
@@ -299,6 +491,10 @@ class AamDiffusionModel(nn.Module):
299
  weight = weight.unsqueeze(-1).expand_as(loss)
300
  return loss * weight
301
 
 
 
 
 
302
  @torch.no_grad()
303
  def sample(
304
  self,
@@ -307,18 +503,28 @@ class AamDiffusionModel(nn.Module):
307
  method: str = "ddim",
308
  shape: Optional[tuple[int, ...]] = None,
309
  device: Optional[torch.device] = None,
 
310
  ) -> torch.Tensor:
311
  """Generate samples via iterative denoising.
312
 
313
- This is the INFERENCE method — start from pure noise and
314
- iteratively denoise to produce coherent text embeddings.
 
 
 
 
 
 
 
315
 
316
  Args:
317
  graph_cond: Graph conditioning dict from GraphConditioningEncoder.
318
  n_steps: Number of denoising steps. Uses config if None.
319
- method: Sampling method ('ddpm' or 'ddim').
 
320
  shape: Shape of the output (batch, seq_len, d_model).
321
  device: Device to generate on.
 
322
 
323
  Returns:
324
  Denoised embeddings of shape (batch, seq_len, d_model).
@@ -330,13 +536,190 @@ class AamDiffusionModel(nn.Module):
330
  if shape is None:
331
  shape = (1, self.config.model.max_seq_len, self.config.model.d_model)
332
 
333
- # Start from pure noise
334
- x = torch.randn(shape, device=device)
335
-
336
  # Get graph conditioning
337
  graph_keys = graph_cond.get("keys")
338
  graph_values = graph_cond.get("values")
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  if method == "ddpm":
341
  # Full DDPM sampling
342
  for t in reversed(range(self.config.diffusion.n_timesteps)):
@@ -346,6 +729,11 @@ class AamDiffusionModel(nn.Module):
346
  graph_keys=graph_keys,
347
  graph_values=graph_values,
348
  )
 
 
 
 
 
349
  x = self.noise_scheduler.step_ddpm(predicted, x, t_tensor)
350
 
351
  elif method == "ddim":
@@ -361,33 +749,61 @@ class AamDiffusionModel(nn.Module):
361
  graph_keys=graph_keys,
362
  graph_values=graph_values,
363
  )
 
 
 
 
 
364
  x = self.noise_scheduler.step_ddim(
365
  predicted, x, t, t_prev,
366
  eta=self.config.diffusion.eta_ddim,
367
  )
 
 
 
 
 
368
 
369
  return x
370
 
 
 
 
 
371
  def embeddings_to_tokens(
372
  self,
373
  embeddings: torch.Tensor,
374
  temperature: float = 1.0,
375
  top_k: int = 50,
 
376
  ) -> torch.Tensor:
377
  """Convert continuous embeddings to discrete token IDs.
378
 
379
  This is the final step of generation — project embeddings
380
  to vocabulary logits and sample tokens.
381
 
 
 
 
 
382
  Args:
383
  embeddings: Denoised embeddings of shape (batch, seq_len, d_model).
384
  temperature: Sampling temperature.
385
  top_k: Top-k sampling cutoff.
 
386
 
387
  Returns:
388
  Token IDs of shape (batch, seq_len).
389
  """
390
- logits = self.lm_head(embeddings) / temperature
 
 
 
 
 
 
 
 
391
 
392
  # Top-k sampling
393
  if top_k > 0:
@@ -400,11 +816,104 @@ class AamDiffusionModel(nn.Module):
400
  -1, sampled_indices.unsqueeze(-1)
401
  ).squeeze(-1)
402
  else:
403
- probs = torch.softmax(logits, dim=-1)
404
  token_ids = torch.argmax(logits, dim=-1)
405
 
406
  return token_ids
407
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  def get_num_params(self) -> int:
409
  """Get total number of parameters."""
410
  return sum(p.numel() for p in self.parameters())
@@ -436,6 +945,10 @@ class AamDiffusionModel(nn.Module):
436
  def load(cls, path: str, device: str = "cpu") -> AamDiffusionModel:
437
  """Load model from checkpoint.
438
 
 
 
 
 
439
  Args:
440
  path: Checkpoint file path.
441
  device: Device to load to.
@@ -468,8 +981,50 @@ class AamDiffusionModel(nn.Module):
468
  logger.warning("Could not reconstruct config from checkpoint, using defaults")
469
  else:
470
  config = config_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  model = cls(config)
472
- model.load_state_dict(checkpoint["model_state_dict"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  model.to(device)
474
  logger.info("Model loaded from %s", path)
475
  return model
 
1
  """
2
+ AAM Diffusion LLM — Complete Model (v2.0)
3
 
4
  Combines the Diffusion Transformer, Graph Encoder, and Noise Scheduler
5
  into a single, unified model for training and inference.
6
 
7
+ v2.0 Upgrades:
8
+ - ContinuousOutputHead (Anchored Decoder) replaces lm_head for
9
+ 2-3 step refinement instead of 50-step DDPM/DDIM
10
+ - EvoformerManager for iterative bidirectional feedback
11
+ - DualMemorySystem for long narrative generation
12
+ - ThinkingToggle for adaptive compute (thinking vs non-thinking)
13
+ - FlowMatchingDecoder as alternative sampling method
14
+ - MCTSReasoner for complex reasoning tasks
15
+ - Full backward compatibility (use_anchored_decoder=False)
16
 
17
  Architecture:
18
  ┌──────────────────────────────────────────────────┐
19
+ │ AAM Diffusion Model v2.0 (The Body)
20
  │ │
21
  │ Input: │
22
  │ - Token IDs (text) │
 
29
  │ 3. Add noise: x_t = schedule.add_noise(x_0, t) │
30
  │ 4. Encode graph conditioning │
31
  │ 5. Predict noise: eps = transformer(x_t, t, c) │
32
+ │ 6. [Optional] Evoformer bidirectional feedback
33
+ │ 7. Compute loss: L = MSE(eps, eps_target) │
34
  │ │
35
+ │ Inference Process (v2.0 Anchored):
36
+ │ 1. Encode graph conditioning │
37
+ │ 2. Transformer produces initial prediction │
38
+ │ 3. Anchored Decoder refines in 2-3 steps │
39
+ │ 4. Convert to tokens via ContinuousOutputHead │
40
+ │ │
41
+ │ Inference Process (Legacy DDPM/DDIM): │
42
  │ 1. Start from pure noise x_T │
43
  │ 2. Encode graph conditioning │
44
  │ 3. For t = T, T-1, ..., 1: │
45
  │ a. Predict noise: eps = transformer(x_t, t) │
46
  │ b. Denoise: x_{t-1} = schedule.step(eps) │
47
  │ 4. Decode final x_0 → text tokens │
 
48
  │ │
49
  │ Key Constraint: │
50
  │ The model CANNOT generate information not │
 
57
  └──────────────────────────────────────────────────┘
58
 
59
  Analogi: Ini adalah seluruh "tubuh" Jin Soun — bukan hanya
60
+ ototnya (transformer), tapi juga sistem saraf (graph encoder),
61
+ kemampuan untuk memperbaiki diri (diffusion denoising), dan
62
+ di v2.0: pikiran sadar (Evoformer), ingatan jangka panjang
63
+ (DualMemory), kemampuan berpikir adaptif (ThinkingToggle),
64
+ dan penalaran mendalam (MCTS).
65
  """
66
 
67
  from __future__ import annotations
68
 
69
  import logging
70
+ from typing import Any, Dict, Optional
71
 
72
  import torch
73
  import torch.nn as nn
 
81
 
82
 
83
  class AamDiffusionModel(nn.Module):
84
+ """Complete AAM Diffusion LLM model (v2.0).
85
 
86
  Combines:
87
  - DiffusionTransformer: Core denoising network
88
  - GraphConditioningEncoder: Encodes graph structure for conditioning
89
  - NoiseScheduler: Manages the diffusion process
90
+ - [v2.0] ContinuousOutputHead: Anchored decoder for 2-3 step refinement
91
+ - [v2.0] EvoformerManager: Iterative bidirectional feedback
92
+ - [v2.0] DualMemorySystem: Working + long-term memory for narratives
93
+ - [v2.0] ThinkingToggle: Adaptive compute based on input complexity
94
+ - [v2.0] FlowMatchingDecoder: Alternative velocity-based sampling
95
+ - [v2.0] MCTSReasoner: Tree search for complex reasoning
96
 
97
  This model is designed to be trained on Graph→Narrative pairs,
98
  where the graph data comes from the RSVS Knowledge Graph and
 
106
  super().__init__()
107
  self.config = config
108
 
109
+ # ----------------------------------------------------------------
110
+ # Feature flags — use getattr for backward compatibility so old
111
+ # configs without the new fields still work.
112
+ # ----------------------------------------------------------------
113
+ self.use_anchored_decoder = getattr(config, "use_anchored_decoder", False)
114
+ self.use_evoformer = getattr(config, "use_evoformer", False)
115
+ self.use_dual_memory = getattr(config, "use_dual_memory", False)
116
+ self.use_thinking_toggle = getattr(config, "use_thinking_toggle", False)
117
+ self.use_flow_matching = getattr(config, "use_flow_matching", False)
118
+ self.use_mcts = getattr(config, "use_mcts", False)
119
+
120
+ # ----------------------------------------------------------------
121
+ # Core components (always present)
122
+ # ----------------------------------------------------------------
123
  self.noise_scheduler = NoiseScheduler(
124
  n_timesteps=config.diffusion.n_timesteps,
125
  schedule_type=config.diffusion.schedule_type,
 
137
 
138
  self.transformer = DiffusionTransformer(config.model)
139
 
140
+ # ----------------------------------------------------------------
141
+ # Output head: v2.0 ContinuousOutputHead or legacy lm_head
142
+ # ----------------------------------------------------------------
143
+ if self.use_anchored_decoder:
144
+ from diffusion_llm.model.anchored_decoder import (
145
+ ContinuousOutputHead,
146
+ AnchoredDecoderConfig,
147
+ )
148
+
149
+ decoder_config = getattr(config, "anchored_decoder", None)
150
+ if decoder_config is None:
151
+ decoder_config = AnchoredDecoderConfig(
152
+ d_model=config.model.d_model,
153
+ d_vocab=config.model.vocab_size,
154
+ )
155
+ self.output_head = ContinuousOutputHead(
156
+ d_model=config.model.d_model,
157
+ d_vocab=config.model.vocab_size,
158
+ decoder_config=decoder_config,
159
+ )
160
+ else:
161
+ # Legacy: simple linear head with tied weights
162
+ self.lm_head = nn.Linear(
163
+ config.model.d_model, config.model.vocab_size, bias=False
164
+ )
165
+ self.lm_head.weight = self.transformer.token_embedding.weight
166
+
167
+ # ----------------------------------------------------------------
168
+ # Optional v2.0 modules — lazy imports
169
+ # ----------------------------------------------------------------
170
+ if self.use_evoformer:
171
+ from diffusion_llm.model.evoformer import EvoformerManager, EvoformerConfig
172
+
173
+ evoformer_config = getattr(config, "evoformer", None)
174
+ if evoformer_config is None:
175
+ evoformer_config = EvoformerConfig(d_model=config.model.d_model)
176
+ else:
177
+ # Sync d_model with the model's actual d_model
178
+ evoformer_config.d_model = config.model.d_model
179
+ self.evoformer = EvoformerManager(evoformer_config)
180
+
181
+ if self.use_dual_memory:
182
+ from diffusion_llm.model.dual_memory import (
183
+ DualMemorySystem,
184
+ DualMemoryConfig,
185
+ )
186
+
187
+ dual_memory_config = getattr(config, "dual_memory", None)
188
+ if dual_memory_config is None:
189
+ dual_memory_config = DualMemoryConfig(d_model=config.model.d_model)
190
+ else:
191
+ # Sync d_model with the model's actual d_model
192
+ dual_memory_config.d_model = config.model.d_model
193
+ self.dual_memory = DualMemorySystem(dual_memory_config)
194
+
195
+ if self.use_thinking_toggle:
196
+ from diffusion_llm.model.thinking_toggle import (
197
+ ThinkingToggle,
198
+ ThinkingMode,
199
+ )
200
+
201
+ thinking_config = getattr(config, "thinking_toggle", None)
202
+ d_thinking = (
203
+ thinking_config.d_model
204
+ if thinking_config is not None
205
+ else config.model.d_model
206
+ )
207
+ threshold = (
208
+ thinking_config.threshold
209
+ if thinking_config is not None
210
+ else 0.5
211
+ )
212
+ self.thinking_toggle = ThinkingToggle(d_thinking, threshold)
213
+ # Re-export for external use
214
+ self.ThinkingMode = ThinkingMode
215
+
216
+ if self.use_flow_matching:
217
+ from diffusion_llm.model.flow_matching import FlowMatchingDecoder
218
+
219
+ flow_config = getattr(config, "flow_matching", None)
220
+ fm_d_model = (
221
+ flow_config.d_model
222
+ if flow_config is not None
223
+ else config.model.d_model
224
+ )
225
+ fm_d_vocab = (
226
+ flow_config.d_vocab
227
+ if flow_config is not None
228
+ else config.model.vocab_size
229
+ )
230
+ fm_num_steps = (
231
+ flow_config.num_steps if flow_config is not None else 3
232
+ )
233
+ self.flow_matching_decoder = FlowMatchingDecoder(
234
+ fm_d_model, fm_d_vocab, fm_num_steps
235
+ )
236
+
237
+ if self.use_mcts:
238
+ from diffusion_llm.model.mcts import MCTSReasoner, MCTSConfig
239
+
240
+ mcts_config = getattr(config, "mcts", None)
241
+ if mcts_config is None:
242
+ mcts_config = MCTSConfig()
243
+ self.mcts_reasoner = MCTSReasoner(
244
+ config.model.d_model, config=mcts_config
245
+ )
246
+
247
+ # ----------------------------------------------------------------
248
  # EMA model (for inference, updated during training)
249
+ # ----------------------------------------------------------------
250
  self._ema_model: Optional[AamDiffusionModel] = None
251
  self._ema_decay = config.training.ema_decay
252
 
253
+ # Build a summary of active modules
254
+ active = []
255
+ if self.use_anchored_decoder:
256
+ active.append("AnchoredDecoder")
257
+ if self.use_evoformer:
258
+ active.append("Evoformer")
259
+ if self.use_dual_memory:
260
+ active.append("DualMemory")
261
+ if self.use_thinking_toggle:
262
+ active.append("ThinkingToggle")
263
+ if self.use_flow_matching:
264
+ active.append("FlowMatching")
265
+ if self.use_mcts:
266
+ active.append("MCTS")
267
+
268
+ module_str = ", ".join(active) if active else "legacy"
269
  logger.info(
270
+ "AamDiffusionModel v2.0 initialized: %s params, %s [modules: %s]",
271
  self._format_params(self.get_num_params()),
272
  config.model_name,
273
+ module_str,
274
  )
275
 
276
+ # ================================================================
277
+ # Forward pass (training)
278
+ # ================================================================
279
+
280
  def forward(
281
  self,
282
  token_ids: torch.Tensor,
 
292
  reasoning_ids: Optional[torch.Tensor] = None,
293
  reasoning_confidence: Optional[torch.Tensor] = None,
294
  source_trust: Optional[torch.Tensor] = None,
295
+ ) -> tuple[torch.Tensor, torch.Tensor]:
296
  """Forward pass for training.
297
 
298
  1. Get clean embeddings from token IDs
299
  2. Add noise at the given timestep
300
  3. Encode graph conditioning
301
  4. Predict noise via transformer
302
+ 5. [v2.0] Optionally apply Evoformer bidirectional feedback
303
+ 6. Return predicted noise (loss computed externally)
304
 
305
  Args:
306
  token_ids: Clean text token IDs, shape (batch, seq_len).
 
318
  source_trust: Source trust score.
319
 
320
  Returns:
321
+ Tuple of (predicted_noise, target_noise).
322
  """
323
  # Step 1: Get clean embeddings (x_0)
324
  x_0 = self.transformer.token_embedding(token_ids)
 
348
  graph_keys = graph_cond.get("keys")
349
  graph_values = graph_cond.get("values")
350
 
351
+ # [v2.0] Dual memory: enrich graph conditioning with memory
352
+ if self.use_dual_memory:
353
+ # Write current graph context to working memory
354
+ if graph_values is not None:
355
+ self.dual_memory.write(graph_values)
356
+ # Read memory-augmented context
357
+ if graph_keys is not None:
358
+ graph_keys = self.dual_memory.read(graph_keys)
359
+ if graph_values is not None:
360
+ graph_values = self.dual_memory.read(graph_values)
361
+
362
  # Step 4: Predict noise via transformer
363
  predicted = self.transformer(
364
  x_t=x_t,
 
367
  graph_values=graph_values,
368
  )
369
 
370
+ # [v2.0] Evoformer: bidirectional feedback between
371
+ # transformer output and graph conditioning
372
+ if self.use_evoformer:
373
+ # Level 2: Bidirectional token update
374
+ predicted = self.evoformer.bidirectional_token_update(predicted)
375
+
376
+ # Level 3: Decoder-predict feedback — graph output refines prediction
377
+ if graph_values is not None:
378
+ # Use mean-pooled graph values as the "decoder output"
379
+ graph_pooled = graph_values.mean(dim=1, keepdim=True).expand_as(
380
+ predicted
381
+ )
382
+ predicted = self.evoformer.apply_decoder_feedback(
383
+ predicted, graph_pooled
384
+ )
385
+
386
+ # Level 4: Prediction recycling — predicted output refines context
387
+ if self.use_anchored_decoder and hasattr(self, "output_head"):
388
+ # Get preliminary logits for prediction recycling
389
+ with torch.no_grad():
390
+ prelim_vectors = self.output_head.get_continuous_vectors(predicted)
391
+ predicted = self.evoformer.apply_prediction_recycling(
392
+ predicted, prelim_vectors
393
+ )
394
+
395
  return predicted, noise
396
 
397
+ # ================================================================
398
+ # Loss computation
399
+ # ================================================================
400
+
401
  def compute_loss(
402
  self,
403
  predicted: torch.Tensor,
 
491
  weight = weight.unsqueeze(-1).expand_as(loss)
492
  return loss * weight
493
 
494
+ # ================================================================
495
+ # Sampling / Inference
496
+ # ================================================================
497
+
498
  @torch.no_grad()
499
  def sample(
500
  self,
 
503
  method: str = "ddim",
504
  shape: Optional[tuple[int, ...]] = None,
505
  device: Optional[torch.device] = None,
506
+ temperature: float = 1.0,
507
  ) -> torch.Tensor:
508
  """Generate samples via iterative denoising.
509
 
510
+ This is the INFERENCE method. Supports multiple sampling
511
+ strategies in v2.0:
512
+
513
+ - "anchored": Uses ContinuousOutputHead for 2-3 step refinement
514
+ (fastest, starts from graph-conditioned prediction)
515
+ - "flow_matching": Uses FlowMatchingDecoder for velocity-based
516
+ sampling (2-3 steps)
517
+ - "ddpm": Legacy full DDPM sampling (many steps)
518
+ - "ddim": Legacy DDIM sampling (fewer steps, deterministic)
519
 
520
  Args:
521
  graph_cond: Graph conditioning dict from GraphConditioningEncoder.
522
  n_steps: Number of denoising steps. Uses config if None.
523
+ method: Sampling method, one of 'anchored', 'flow_matching',
524
+ 'ddpm', or 'ddim'.
525
  shape: Shape of the output (batch, seq_len, d_model).
526
  device: Device to generate on.
527
+ temperature: Sampling temperature.
528
 
529
  Returns:
530
  Denoised embeddings of shape (batch, seq_len, d_model).
 
536
  if shape is None:
537
  shape = (1, self.config.model.max_seq_len, self.config.model.d_model)
538
 
 
 
 
539
  # Get graph conditioning
540
  graph_keys = graph_cond.get("keys")
541
  graph_values = graph_cond.get("values")
542
 
543
+ # [v2.0] Dual memory: augment graph conditioning with memory
544
+ if self.use_dual_memory:
545
+ if graph_values is not None:
546
+ self.dual_memory.write(graph_values)
547
+ if graph_keys is not None:
548
+ graph_keys = self.dual_memory.read(graph_keys)
549
+ if graph_values is not None:
550
+ graph_values = self.dual_memory.read(graph_values)
551
+
552
+ # ----------------------------------------------------------
553
+ # METHOD: Anchored Decoder (2-3 step refinement)
554
+ # ----------------------------------------------------------
555
+ if method == "anchored" and hasattr(self, "output_head"):
556
+ return self._sample_anchored(
557
+ graph_keys, graph_values, shape, device, n_steps, temperature
558
+ )
559
+
560
+ # ----------------------------------------------------------
561
+ # METHOD: Flow Matching Decoder
562
+ # ----------------------------------------------------------
563
+ if method == "flow_matching" and hasattr(self, "flow_matching_decoder"):
564
+ return self._sample_flow_matching(
565
+ graph_keys, graph_values, shape, device
566
+ )
567
+
568
+ # ----------------------------------------------------------
569
+ # METHOD: Legacy DDPM / DDIM
570
+ # ----------------------------------------------------------
571
+ return self._sample_legacy(
572
+ graph_keys, graph_values, shape, device, n_steps, method
573
+ )
574
+
575
def _sample_anchored(
    self,
    graph_keys: Optional[torch.Tensor],
    graph_values: Optional[torch.Tensor],
    shape: tuple[int, ...],
    device: torch.device,
    n_steps: int,
    temperature: float,
) -> torch.Tensor:
    """Anchored decoding: one transformer pass, then output-head refinement.

    Instead of denoising from pure noise over many steps, a single
    graph-conditioned transformer pass produces an "anchor" prediction,
    which ``self.output_head`` then turns into vocabulary logits. The
    logits are finally mapped back to embedding space as the
    probability-weighted average of the token embedding matrix.

    Args:
        graph_keys: Graph conditioning keys, or None.
        graph_values: Graph conditioning values, or None. When present
            they are mean-pooled over dim 1 to form the context handed
            to the output head (and the Evoformer decoder feedback).
        shape: Shape of the initial noise tensor; ``shape[0]`` is used
            as the batch size. Presumably (batch, seq_len, d_model).
        device: Device on which the noise and timestep are created.
        n_steps: Requested number of refinement steps. Only used as the
            fallback value in the debug log unless ThinkingToggle
            replaces it (see the review note in the body).
        temperature: Softmax temperature applied to the logits; must be
            strictly positive.

    Returns:
        Soft embeddings: ``softmax(logits / temperature)`` multiplied by
        ``self.transformer.token_embedding.weight``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    # Dividing the logits by zero (or a negative value) silently turns
    # the softmax output into NaN/inverted probabilities, so reject it
    # before doing any work.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Step 1: Get an initial prediction from the transformer.
    # The anchor is computed at timestep t=0 (the lowest-noise step).
    batch_size = shape[0]
    t_init = torch.full(
        (batch_size,), 0, device=device, dtype=torch.long
    )

    # Start from low-amplitude Gaussian noise (standard normal scaled
    # by 0.1) rather than full-variance noise.
    x = torch.randn(shape, device=device) * 0.1

    # Single transformer forward pass to get the initial anchor
    initial_pred = self.transformer(
        x_t=x, t=t_init,
        graph_keys=graph_keys,
        graph_values=graph_values,
    )

    # [v2.0] Evoformer feedback on the initial prediction
    if self.use_evoformer:
        initial_pred = self.evoformer.bidirectional_token_update(initial_pred)
        if graph_values is not None:
            # Broadcast the mean-pooled graph values to the prediction's
            # shape so they can act as the "decoder output" signal.
            graph_pooled = graph_values.mean(dim=1, keepdim=True).expand_as(
                initial_pred
            )
            initial_pred = self.evoformer.apply_decoder_feedback(
                initial_pred, graph_pooled
            )

    # [v2.0] ThinkingToggle: determine refinement depth.
    # NOTE(review): refine_steps is only reported in the debug log below;
    # it is not forwarded to self.output_head, so it does not currently
    # change how many refinement steps the head performs. Confirm whether
    # the head accepts a step count and wire it through if so.
    refine_steps = n_steps
    if self.use_thinking_toggle:
        assessment = self.thinking_toggle(initial_pred)
        # Scale a base of 3 steps by the batch-mean depth multiplier and
        # clamp the result to the range [2, 5].
        depth_mult = assessment.depth_multiplier.mean().item()
        refine_steps = max(2, min(5, int(3 * depth_mult)))
        logger.debug(
            "ThinkingToggle: mode=%s, depth_mult=%.2f, refine_steps=%d",
            assessment.mode.value,
            depth_mult,
            refine_steps,
        )

    # Step 2: Refine with the anchored output head. Per the original
    # author's comment the head performs its refinement internally;
    # that behaviour is not visible from this method.
    graph_context = graph_values.mean(dim=1) if graph_values is not None else None
    logits, info = self.output_head(
        initial_pred,
        use_diffusion=True,
        context=graph_context,
    )

    # The head returns logits; project them back to embedding space as
    # the expectation of the token embeddings under the tempered softmax.
    logits_scaled = logits / temperature
    probs = torch.softmax(logits_scaled, dim=-1)
    embeddings = torch.matmul(
        probs, self.transformer.token_embedding.weight
    )

    logger.debug(
        "Anchored sampling: %d refine steps, delta=%.4f",
        info.get("n_refine_steps", refine_steps),
        info.get("refinement_delta", 0.0),
    )

    return embeddings
660
+
661
def _sample_flow_matching(
    self,
    graph_keys: Optional[torch.Tensor],
    graph_values: Optional[torch.Tensor],
    shape: tuple[int, ...],
    device: torch.device,
) -> torch.Tensor:
    """Sample via the flow-matching decoder.

    Runs one graph-conditioned transformer pass at timestep 0 on
    low-amplitude Gaussian noise, optionally applies Evoformer feedback,
    hands the result to ``self.flow_matching_decoder`` and converts the
    decoder's ``refined_logits`` back to embedding space through the
    token embedding matrix.

    Args:
        graph_keys: Graph conditioning keys, or None.
        graph_values: Graph conditioning values, or None.
        shape: Shape of the initial noise tensor; ``shape[0]`` is the
            batch size.
        device: Device on which the noise and timestep are created.

    Returns:
        ``softmax(refined_logits)`` multiplied by the token embedding
        weight matrix.
    """
    n_batch = shape[0]

    # Initial hidden state: transformer evaluated at t=0 on noise scaled
    # down to one tenth of a standard normal.
    timestep = torch.full((n_batch,), 0, device=device, dtype=torch.long)
    noise = torch.randn(shape, device=device) * 0.1
    hidden = self.transformer(
        x_t=noise,
        t=timestep,
        graph_keys=graph_keys,
        graph_values=graph_values,
    )

    # [v2.0] Optional Evoformer feedback on the initial prediction.
    if self.use_evoformer:
        hidden = self.evoformer.bidirectional_token_update(hidden)
        if graph_values is not None:
            pooled = graph_values.mean(dim=1, keepdim=True)
            hidden = self.evoformer.apply_decoder_feedback(
                hidden, pooled.expand_as(hidden)
            )

    # Flow-matching refinement of the (possibly updated) hidden state.
    flow_output = self.flow_matching_decoder(hidden)

    # Map the refined logits to embeddings: probability-weighted average
    # of the token embedding rows.
    token_probs = torch.softmax(flow_output.refined_logits, dim=-1)
    embedding_table = self.transformer.token_embedding.weight
    embeddings = torch.matmul(token_probs, embedding_table)

    logger.debug(
        "Flow matching sampling: %d steps",
        flow_output.num_steps,
    )

    return embeddings
709
+
710
+ def _sample_legacy(
711
+ self,
712
+ graph_keys: Optional[torch.Tensor],
713
+ graph_values: Optional[torch.Tensor],
714
+ shape: tuple[int, ...],
715
+ device: torch.device,
716
+ n_steps: int,
717
+ method: str,
718
+ ) -> torch.Tensor:
719
+ """Legacy DDPM/DDIM sampling (v1.0 compatible)."""
720
+ # Start from pure noise
721
+ x = torch.randn(shape, device=device)
722
+
723
  if method == "ddpm":
724
  # Full DDPM sampling
725
  for t in reversed(range(self.config.diffusion.n_timesteps)):
 
729
  graph_keys=graph_keys,
730
  graph_values=graph_values,
731
  )
732
+
733
+ # [v2.0] Evoformer feedback per step (expensive, only if enabled)
734
+ if self.use_evoformer:
735
+ predicted = self.evoformer.bidirectional_token_update(predicted)
736
+
737
  x = self.noise_scheduler.step_ddpm(predicted, x, t_tensor)
738
 
739
  elif method == "ddim":
 
749
  graph_keys=graph_keys,
750
  graph_values=graph_values,
751
  )
752
+
753
+ # [v2.0] Evoformer feedback per step
754
+ if self.use_evoformer:
755
+ predicted = self.evoformer.bidirectional_token_update(predicted)
756
+
757
  x = self.noise_scheduler.step_ddim(
758
  predicted, x, t, t_prev,
759
  eta=self.config.diffusion.eta_ddim,
760
  )
761
+ else:
762
+ raise ValueError(
763
+ f"Unknown sampling method: {method}. "
764
+ f"Use 'anchored', 'flow_matching', 'ddpm', or 'ddim'."
765
+ )
766
 
767
  return x
768
 
769
+ # ================================================================
770
+ # Embedding → Token conversion
771
+ # ================================================================
772
+
773
  def embeddings_to_tokens(
774
  self,
775
  embeddings: torch.Tensor,
776
  temperature: float = 1.0,
777
  top_k: int = 50,
778
+ graph_context: Optional[torch.Tensor] = None,
779
  ) -> torch.Tensor:
780
  """Convert continuous embeddings to discrete token IDs.
781
 
782
  This is the final step of generation — project embeddings
783
  to vocabulary logits and sample tokens.
784
 
785
+ v2.0: When ContinuousOutputHead is available, it uses the
786
+ anchored decoder for refined logits. Otherwise falls back
787
+ to the standard lm_head.
788
+
789
  Args:
790
  embeddings: Denoised embeddings of shape (batch, seq_len, d_model).
791
  temperature: Sampling temperature.
792
  top_k: Top-k sampling cutoff.
793
+ graph_context: Optional graph conditioning for anchored decoder.
794
 
795
  Returns:
796
  Token IDs of shape (batch, seq_len).
797
  """
798
+ if hasattr(self, "output_head"):
799
+ # v2.0: Use anchored decoder for refined logit prediction
800
+ logits, info = self.output_head(
801
+ embeddings, use_diffusion=True, context=graph_context
802
+ )
803
+ logits = logits / temperature
804
+ else:
805
+ # Legacy: simple linear projection
806
+ logits = self.lm_head(embeddings) / temperature
807
 
808
  # Top-k sampling
809
  if top_k > 0:
 
816
  -1, sampled_indices.unsqueeze(-1)
817
  ).squeeze(-1)
818
  else:
 
819
  token_ids = torch.argmax(logits, dim=-1)
820
 
821
  return token_ids
822
 
823
+ # ================================================================
824
+ # ThinkingToggle integration
825
+ # ================================================================
826
+
827
def assess_thinking(
    self, hidden_states: torch.Tensor, force_mode=None
) -> Optional[Any]:
    """Run the ThinkingToggle on a batch of hidden states, if enabled.

    The call is gated by ``self.use_thinking_toggle``; when that flag is
    falsy the toggle module is never touched and ``None`` is returned.

    Args:
        hidden_states: Hidden states handed to the toggle unchanged
            (documented upstream as shape (batch, seq_len, d_model)).
        force_mode: Optional mode override, forwarded to the toggle as
            the ``force_mode`` keyword argument.

    Returns:
        Whatever ``self.thinking_toggle`` returns when the toggle is
        enabled, otherwise ``None``.
    """
    if self.use_thinking_toggle:
        assessment = self.thinking_toggle(hidden_states, force_mode=force_mode)
        return assessment
    # Toggle disabled: nothing to assess.
    return None
844
+
845
+ # ================================================================
846
+ # MCTS integration
847
+ # ================================================================
848
+
849
def reason_with_mcts(
    self,
    hidden_states: torch.Tensor,
    num_simulations: Optional[int] = None,
) -> Optional[tuple[torch.Tensor, Dict[str, Any]]]:
    """Delegate to the MCTS reasoner when MCTS is switched on.

    The call is gated by ``self.use_mcts``; when that flag is falsy the
    reasoner is never invoked and ``None`` is returned.

    Args:
        hidden_states: Hidden states passed to the reasoner unchanged.
        num_simulations: Optional simulation-count override, forwarded
            as the ``num_simulations`` keyword argument.

    Returns:
        The reasoner's result (documented upstream as a tuple of
        (action_probs, info_dict)) when MCTS is enabled, else ``None``.
    """
    if self.use_mcts:
        outcome = self.mcts_reasoner(
            hidden_states, num_simulations=num_simulations
        )
        return outcome
    # MCTS disabled: no reasoning pass is performed.
    return None
868
+
869
+ # ================================================================
870
+ # Dual Memory management
871
+ # ================================================================
872
+
873
def memory_consolidate(self) -> None:
    """Ask the dual-memory module to consolidate, if it is enabled.

    Calls ``self.dual_memory.consolidate()`` only when
    ``self.use_dual_memory`` is truthy; otherwise this is a no-op.
    Any value returned by ``consolidate()`` is discarded.
    """
    if not self.use_dual_memory:
        return
    self.dual_memory.consolidate()
880
+
881
def memory_clear(self) -> None:
    """Ask the dual-memory module to clear itself, if it is enabled.

    Calls ``self.dual_memory.clear()`` only when ``self.use_dual_memory``
    is truthy; otherwise this is a no-op. Any value returned by
    ``clear()`` is discarded.
    """
    if not self.use_dual_memory:
        return
    self.dual_memory.clear()
888
+
889
def memory_stats(self) -> Dict[str, object]:
    """Report statistics from the dual-memory module.

    Returns:
        The result of ``self.dual_memory.get_stats()`` when
        ``self.use_dual_memory`` is truthy, otherwise a new empty dict.
    """
    return self.dual_memory.get_stats() if self.use_dual_memory else {}
898
+
899
+ # ================================================================
900
+ # Evoformer statistics
901
+ # ================================================================
902
+
903
def evoformer_stats(self) -> Dict[str, object]:
    """Report statistics from the Evoformer module.

    Returns:
        The result of ``self.evoformer.get_stats()`` when
        ``self.use_evoformer`` is truthy, otherwise a new empty dict.
    """
    if not self.use_evoformer:
        # Evoformer disabled: there is nothing to report.
        return {}
    stats = self.evoformer.get_stats()
    return stats
912
+
913
+ # ================================================================
914
+ # Utility methods
915
+ # ================================================================
916
+
917
  def get_num_params(self) -> int:
918
  """Get total number of parameters."""
919
  return sum(p.numel() for p in self.parameters())
 
945
  def load(cls, path: str, device: str = "cpu") -> AamDiffusionModel:
946
  """Load model from checkpoint.
947
 
948
+ Supports both v2.0 and v1.0 checkpoints. Missing v2.0 config
949
+ fields are filled with defaults (disabled), ensuring backward
950
+ compatibility.
951
+
952
  Args:
953
  path: Checkpoint file path.
954
  device: Device to load to.
 
981
  logger.warning("Could not reconstruct config from checkpoint, using defaults")
982
  else:
983
  config = config_dict
984
+
985
+ # v2.0 config fields — attach from checkpoint dict if present
986
+ # so the model initializes optional modules correctly
987
+ for flag in [
988
+ "use_anchored_decoder", "use_evoformer", "use_dual_memory",
989
+ "use_thinking_toggle", "use_flow_matching", "use_mcts",
990
+ ]:
991
+ if flag not in config_dict:
992
+ # Old checkpoint — ensure the flag is False
993
+ if not hasattr(config, flag):
994
+ setattr(config, flag, False)
995
+
996
+ # Attach sub-configs if present in checkpoint
997
+ for sub_key in [
998
+ "anchored_decoder", "evoformer", "dual_memory",
999
+ "thinking_toggle", "flow_matching", "mcts",
1000
+ ]:
1001
+ if sub_key in config_dict and not hasattr(config, sub_key):
1002
+ setattr(config, sub_key, config_dict[sub_key])
1003
+
1004
  model = cls(config)
1005
+
1006
+ # Load state dict with partial matching for backward compatibility
1007
+ state_dict = checkpoint["model_state_dict"]
1008
+ model_state = model.state_dict()
1009
+
1010
+ # Separate keys that match vs. don't match
1011
+ matched = {k: v for k, v in state_dict.items() if k in model_state}
1012
+ missing = [k for k in model_state if k not in state_dict]
1013
+ unexpected = [k for k in state_dict if k not in model_state]
1014
+
1015
+ if missing:
1016
+ logger.info(
1017
+ "Loading checkpoint: %d keys missing (new v2.0 modules), "
1018
+ "will use random init for those.",
1019
+ len(missing),
1020
+ )
1021
+ if unexpected:
1022
+ logger.info(
1023
+ "Loading checkpoint: %d unexpected keys (legacy modules).",
1024
+ len(unexpected),
1025
+ )
1026
+
1027
+ model.load_state_dict(matched, strict=False)
1028
  model.to(device)
1029
  logger.info("Model loaded from %s", path)
1030
  return model