Wolfvin committed on
Commit
3a8397a
·
verified ·
1 Parent(s): 9bdac3f

Upload diffusion_llm/inference/generator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_llm/inference/generator.py +546 -75
diffusion_llm/inference/generator.py CHANGED
@@ -1,11 +1,27 @@
1
  """
2
- AAM Diffusion LLM — Inference Generator
3
 
4
  Generates natural language narratives from graph conditioning
5
  using the trained diffusion model.
6
 
7
- The generation process:
 
 
 
 
 
 
 
 
8
  1. Encode graph conditioning (evidence, anomalies, reasoning)
 
 
 
 
 
 
 
 
9
  2. Start from pure noise in the latent space
10
  3. Iteratively denoise for N steps
11
  4. Convert denoised embeddings to token IDs
@@ -13,8 +29,9 @@ The generation process:
13
 
14
  Analogi: Seperti Jin Soun akhirnya "berbicara" — dari
15
  pikiran yang kabur (noise) menjadi kata-kata yang jelas
16
- (denoised narrative). Setiap langkah denoising = satu
17
- langkah lebih dekat ke koherensi.
 
18
  """
19
 
20
  from __future__ import annotations
@@ -22,7 +39,7 @@ from __future__ import annotations
22
  import logging
23
  import time
24
  from dataclasses import dataclass, field
25
- from typing import Optional
26
 
27
  import torch
28
 
@@ -40,6 +57,7 @@ class GenerationResult:
40
  Contains the generated narrative plus metadata about
41
  how it was generated, for traceability.
42
  """
 
43
  narrative: str
44
  """Generated narrative text."""
45
 
@@ -64,9 +82,25 @@ class GenerationResult:
64
  language: str = "id"
65
  """Output language."""
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def to_dict(self) -> dict:
68
  """Serialize to dictionary."""
69
- return {
70
  "narrative": self.narrative,
71
  "n_diffusion_steps": self.n_diffusion_steps,
72
  "generation_time_s": round(self.generation_time_s, 3),
@@ -74,16 +108,32 @@ class GenerationResult:
74
  "evidence_used": self.evidence_used,
75
  "confidence": round(self.confidence, 3),
76
  "language": self.language,
 
77
  }
 
 
 
 
 
 
 
 
78
 
79
 
80
  class AamGenerator:
81
- """Generate narratives from graph conditioning using the trained model.
82
 
83
  This is the main inference interface. It takes graph-structured
84
  data (from the RSVS Knowledge Graph) and produces natural
85
  language narratives through the diffusion denoising process.
86
 
 
 
 
 
 
 
 
87
  Usage:
88
  # Load model and tokenizer
89
  config = AamDiffusionConfig.from_json("config.json")
@@ -93,12 +143,21 @@ class AamGenerator:
93
  # Create generator
94
  generator = AamGenerator(model, tokenizer, config)
95
 
96
- # Generate narrative
97
  result = generator.generate(
98
  trigger="Siapa yang mencuri Snow Plum Pill?",
99
  evidence_nodes=["hefei", "diancang", "ju_jangmok"],
100
  anomalies=["no external pill consumption"],
101
  reasoning_steps=["Diancang pair was in Hefei before theft"],
 
 
 
 
 
 
 
 
 
102
  )
103
  print(result.narrative)
104
 
@@ -125,6 +184,25 @@ class AamGenerator:
125
  # Set model to eval mode
126
  self.model.eval()
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  @torch.no_grad()
129
  def generate(
130
  self,
@@ -139,14 +217,19 @@ class AamGenerator:
139
  temperature: Optional[float] = None,
140
  language: Optional[str] = None,
141
  max_sentences: Optional[int] = None,
 
 
 
142
  ) -> GenerationResult:
143
  """Generate a narrative from graph conditioning.
144
 
145
  This is the main generation method. It:
146
  1. Tokenizes the graph conditioning data
147
  2. Encodes it through the graph encoder
148
- 3. Starts from noise and iteratively denoises
149
- 4. Converts the result to text
 
 
150
 
151
  Args:
152
  trigger: The trigger question or topic.
@@ -160,6 +243,12 @@ class AamGenerator:
160
  temperature: Override sampling temperature.
161
  language: Override output language.
162
  max_sentences: Maximum sentences in output.
 
 
 
 
 
 
163
 
164
  Returns:
165
  GenerationResult with the narrative and metadata.
@@ -172,70 +261,49 @@ class AamGenerator:
172
  language = language or self.inference_config.language
173
  max_sentences = max_sentences or self.inference_config.max_output_sentences
174
 
175
- # --- Step 1: Tokenize graph conditioning ---
176
- evidence_ids_tensor = None
177
- evidence_conf_tensor = None
178
- anomaly_ids_tensor = None
179
- anomaly_conf_tensor = None
180
- reasoning_ids_tensor = None
181
- reasoning_conf_tensor = None
182
-
183
- if evidence_nodes:
184
- evidence_ids_list = []
185
- evidence_conf_list = []
186
- for node in evidence_nodes[:self.config.graph_encoder.max_evidence_nodes]:
187
- ids = self.tokenizer.encode(node, add_special=False)
188
- ids = self.tokenizer.pad_sequence(ids, 32)
189
- evidence_ids_list.append(ids)
190
- conf = (confidence_map or {}).get(node, 0.7)
191
- evidence_conf_list.append(conf)
192
-
193
- while len(evidence_ids_list) < self.config.graph_encoder.max_evidence_nodes:
194
- evidence_ids_list.append([0] * 32)
195
- evidence_conf_list.append(0.0)
196
-
197
- evidence_ids_tensor = torch.tensor(
198
- [evidence_ids_list], dtype=torch.long, device=self.device
199
- )
200
- evidence_conf_tensor = torch.tensor(
201
- [evidence_conf_list], dtype=torch.float32, device=self.device
202
  )
 
203
 
204
- if anomalies:
205
- anomaly_ids_list = []
206
- for anom in anomalies[:self.config.graph_encoder.max_anomalies]:
207
- ids = self.tokenizer.encode(anom, add_special=False)
208
- ids = self.tokenizer.pad_sequence(ids, 32)
209
- anomaly_ids_list.append(ids)
210
-
211
- while len(anomaly_ids_list) < self.config.graph_encoder.max_anomalies:
212
- anomaly_ids_list.append([0] * 32)
213
-
214
- anomaly_ids_tensor = torch.tensor(
215
- [anomaly_ids_list], dtype=torch.long, device=self.device
216
- )
217
- anomaly_conf_tensor = torch.full(
218
- (1, self.config.graph_encoder.max_anomalies),
219
- 0.6, dtype=torch.float32, device=self.device,
220
  )
 
221
 
222
- if reasoning_steps:
223
- reasoning_ids_list = []
224
- for step in reasoning_steps[:self.config.graph_encoder.max_reasoning_steps]:
225
- ids = self.tokenizer.encode(step, add_special=False)
226
- ids = self.tokenizer.pad_sequence(ids, 32)
227
- reasoning_ids_list.append(ids)
228
-
229
- while len(reasoning_ids_list) < self.config.graph_encoder.max_reasoning_steps:
230
- reasoning_ids_list.append([0] * 32)
231
-
232
- reasoning_ids_tensor = torch.tensor(
233
- [reasoning_ids_list], dtype=torch.long, device=self.device
234
- )
235
- reasoning_conf_tensor = torch.full(
236
- (1, self.config.graph_encoder.max_reasoning_steps),
237
- 0.7, dtype=torch.float32, device=self.device,
238
- )
 
239
 
240
  source_trust_tensor = torch.tensor(
241
  [source_trust], dtype=torch.float32, device=self.device
@@ -249,10 +317,65 @@ class AamGenerator:
249
  anomaly_confidence=anomaly_conf_tensor,
250
  reasoning_ids=reasoning_ids_tensor,
251
  reasoning_confidence=reasoning_conf_tensor,
 
 
252
  source_trust=source_trust_tensor,
253
  )
254
 
255
- # --- Step 3: Generate via diffusion denoising ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  shape = (
257
  1,
258
  self.config.model.max_seq_len,
@@ -262,18 +385,27 @@ class AamGenerator:
262
  denoised = self.model.sample(
263
  graph_cond=graph_cond,
264
  n_steps=n_steps,
265
- method=self.config.diffusion.sampling_method,
266
  shape=shape,
267
  device=self.device,
 
268
  )
269
 
270
- # --- Step 4: Convert to tokens ---
 
 
 
 
 
 
271
  token_ids = self.model.embeddings_to_tokens(
272
- denoised, temperature=temperature,
 
273
  top_k=self.inference_config.top_k,
 
274
  )
275
 
276
- # --- Step 5: Detokenize ---
277
  token_list = token_ids[0].cpu().tolist()
278
  narrative = self.tokenizer.decode(token_list, skip_special=True)
279
 
@@ -290,6 +422,13 @@ class AamGenerator:
290
  if confidence_map:
291
  avg_confidence = sum(confidence_map.values()) / len(confidence_map)
292
 
 
 
 
 
 
 
 
293
  return GenerationResult(
294
  narrative=narrative,
295
  token_ids=token_list,
@@ -299,8 +438,262 @@ class AamGenerator:
299
  evidence_used=evidence_nodes or [],
300
  confidence=avg_confidence,
301
  language=language,
 
 
 
 
 
302
  )
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  def generate_batch(
305
  self,
306
  triggers: list[str],
@@ -331,3 +724,81 @@ class AamGenerator:
331
  )
332
  results.append(result)
333
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ AAM Diffusion LLM — Inference Generator (v2.0)
3
 
4
  Generates natural language narratives from graph conditioning
5
  using the trained diffusion model.
6
 
7
+ v2.0 Upgrades:
8
+ - ThinkingToggle for adaptive inference (thinking vs non-thinking)
9
+ - Anchored decoding method (2-3 steps instead of 50)
10
+ - Flow matching method (velocity-based 2-3 step sampling)
11
+ - MCTS integration for complex reasoning tasks
12
+ - DualMemorySystem for long narrative generation
13
+ - Full backward compatibility with v1.0 generation
14
+
15
+ The generation process (v2.0 Anchored):
16
  1. Encode graph conditioning (evidence, anomalies, reasoning)
17
+ 2. [Optional] ThinkingToggle assesses complexity
18
+ 3. [Optional] MCTS explores narrative arrangements for complex inputs
19
+ 4. Generate via anchored decoding (2-3 refinement steps)
20
+ 5. Convert denoised embeddings to token IDs
21
+ 6. Detokenize to natural language text
22
+
23
+ The generation process (Legacy DDPM/DDIM):
24
+ 1. Encode graph conditioning
25
  2. Start from pure noise in the latent space
26
  3. Iteratively denoise for N steps
27
  4. Convert denoised embeddings to token IDs
 
29
 
30
  Analogi: Seperti Jin Soun akhirnya "berbicara" — dari
31
  pikiran yang kabur (noise) menjadi kata-kata yang jelas
32
+ (denoised narrative). Di v2.0, Jin Soun sekarang bisa
33
+ memilih: berbicara cepat untuk hal sederhana (non-thinking),
34
+ atau berpikir dalam untuk masalah rumit (thinking + MCTS).
35
  """
36
 
37
  from __future__ import annotations
 
39
  import logging
40
  import time
41
  from dataclasses import dataclass, field
42
+ from typing import Any, Dict, Optional
43
 
44
  import torch
45
 
 
57
  Contains the generated narrative plus metadata about
58
  how it was generated, for traceability.
59
  """
60
+
61
  narrative: str
62
  """Generated narrative text."""
63
 
 
82
  language: str = "id"
83
  """Output language."""
84
 
85
+ # v2.0 metadata
86
+ sampling_method: str = "ddim"
87
+ """Sampling method used ('anchored', 'flow_matching', 'ddpm', 'ddim')."""
88
+
89
+ thinking_mode: str = ""
90
+ """ThinkingToggle mode: 'thinking', 'non_thinking', or '' if disabled."""
91
+
92
+ complexity_score: float = 0.0
93
+ """Complexity score from ThinkingToggle (0.0 if disabled)."""
94
+
95
+ mcts_used: bool = False
96
+ """Whether MCTS reasoning was used."""
97
+
98
+ memory_stats: Dict[str, object] = field(default_factory=dict)
99
+ """DualMemory statistics at generation time."""
100
+
101
  def to_dict(self) -> dict:
102
  """Serialize to dictionary."""
103
+ result = {
104
  "narrative": self.narrative,
105
  "n_diffusion_steps": self.n_diffusion_steps,
106
  "generation_time_s": round(self.generation_time_s, 3),
 
108
  "evidence_used": self.evidence_used,
109
  "confidence": round(self.confidence, 3),
110
  "language": self.language,
111
+ "sampling_method": self.sampling_method,
112
  }
113
+ if self.thinking_mode:
114
+ result["thinking_mode"] = self.thinking_mode
115
+ result["complexity_score"] = round(self.complexity_score, 3)
116
+ if self.mcts_used:
117
+ result["mcts_used"] = True
118
+ if self.memory_stats:
119
+ result["memory_stats"] = self.memory_stats
120
+ return result
121
 
122
 
123
  class AamGenerator:
124
+ """Generate narratives from graph conditioning using the trained model (v2.0).
125
 
126
  This is the main inference interface. It takes graph-structured
127
  data (from the RSVS Knowledge Graph) and produces natural
128
  language narratives through the diffusion denoising process.
129
 
130
+ v2.0 features:
131
+ - Adaptive compute via ThinkingToggle
132
+ - Fast anchored decoding (2-3 steps)
133
+ - Flow matching decoding
134
+ - MCTS for complex reasoning
135
+ - Dual memory for long narratives
136
+
137
  Usage:
138
  # Load model and tokenizer
139
  config = AamDiffusionConfig.from_json("config.json")
 
143
  # Create generator
144
  generator = AamGenerator(model, tokenizer, config)
145
 
146
+ # Generate narrative (v2.0 anchored decoding)
147
  result = generator.generate(
148
  trigger="Siapa yang mencuri Snow Plum Pill?",
149
  evidence_nodes=["hefei", "diancang", "ju_jangmok"],
150
  anomalies=["no external pill consumption"],
151
  reasoning_steps=["Diancang pair was in Hefei before theft"],
152
+ method="anchored",
153
+ )
154
+ print(result.narrative)
155
+
156
+ # Generate narrative (legacy DDIM)
157
+ result = generator.generate(
158
+ trigger="Summary of events",
159
+ evidence_nodes=["event_a", "event_b"],
160
+ method="ddim",
161
  )
162
  print(result.narrative)
163
 
 
184
  # Set model to eval mode
185
  self.model.eval()
186
 
187
+ # Feature detection
188
+ self._has_anchored_decoder = hasattr(model, "output_head")
189
+ self._has_thinking_toggle = hasattr(model, "thinking_toggle")
190
+ self._has_flow_matching = hasattr(model, "flow_matching_decoder")
191
+ self._has_mcts = hasattr(model, "mcts_reasoner")
192
+ self._has_dual_memory = hasattr(model, "dual_memory")
193
+ self._has_evoformer = hasattr(model, "evoformer")
194
+
195
+ logger.info(
196
+ "AamGenerator v2.0 initialized. Features: anchored=%s, thinking=%s, "
197
+ "flow=%s, mcts=%s, memory=%s, evoformer=%s",
198
+ self._has_anchored_decoder,
199
+ self._has_thinking_toggle,
200
+ self._has_flow_matching,
201
+ self._has_mcts,
202
+ self._has_dual_memory,
203
+ self._has_evoformer,
204
+ )
205
+
206
  @torch.no_grad()
207
  def generate(
208
  self,
 
217
  temperature: Optional[float] = None,
218
  language: Optional[str] = None,
219
  max_sentences: Optional[int] = None,
220
+ method: Optional[str] = None,
221
+ use_mcts: Optional[bool] = None,
222
+ force_thinking_mode: Optional[str] = None,
223
  ) -> GenerationResult:
224
  """Generate a narrative from graph conditioning.
225
 
226
  This is the main generation method. It:
227
  1. Tokenizes the graph conditioning data
228
  2. Encodes it through the graph encoder
229
+ 3. [v2.0] Optionally assesses thinking complexity
230
+ 4. [v2.0] Optionally runs MCTS for complex reasoning
231
+ 5. Generates via the selected sampling method
232
+ 6. Converts the result to text
233
 
234
  Args:
235
  trigger: The trigger question or topic.
 
243
  temperature: Override sampling temperature.
244
  language: Override output language.
245
  max_sentences: Maximum sentences in output.
246
+ method: Sampling method — 'anchored', 'flow_matching',
247
+ 'ddpm', 'ddim', or None (uses config default).
248
+ use_mcts: Override whether to use MCTS. None = auto-decide
249
+ based on ThinkingToggle assessment.
250
+ force_thinking_mode: Force thinking mode ('thinking' or
251
+ 'non_thinking'). None = auto-decide.
252
 
253
  Returns:
254
  GenerationResult with the narrative and metadata.
 
261
  language = language or self.inference_config.language
262
  max_sentences = max_sentences or self.inference_config.max_output_sentences
263
 
264
+ # Determine sampling method
265
+ if method is None:
266
+ # Default to anchored if available, else use config
267
+ if self._has_anchored_decoder:
268
+ method = "anchored"
269
+ else:
270
+ method = self.config.diffusion.sampling_method
271
+
272
+ # Validate method availability
273
+ if method == "anchored" and not self._has_anchored_decoder:
274
+ logger.warning(
275
+ "Anchored decoding requested but ContinuousOutputHead not "
276
+ "available. Falling back to '%s'.",
277
+ self.config.diffusion.sampling_method,
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  )
279
+ method = self.config.diffusion.sampling_method
280
 
281
+ if method == "flow_matching" and not self._has_flow_matching:
282
+ logger.warning(
283
+ "Flow matching requested but FlowMatchingDecoder not "
284
+ "available. Falling back to '%s'.",
285
+ self.config.diffusion.sampling_method,
 
 
 
 
 
 
 
 
 
 
 
286
  )
287
+ method = self.config.diffusion.sampling_method
288
 
289
+ # --- Step 1: Tokenize graph conditioning ---
290
+ (
291
+ evidence_ids_tensor,
292
+ evidence_conf_tensor,
293
+ anomaly_ids_tensor,
294
+ anomaly_conf_tensor,
295
+ reasoning_ids_tensor,
296
+ reasoning_conf_tensor,
297
+ composition_ids_tensor,
298
+ composition_conf_tensor,
299
+ ) = self._tokenize_graph_conditioning(
300
+ evidence_nodes=evidence_nodes,
301
+ compositions=compositions,
302
+ confidence_map=confidence_map,
303
+ anomalies=anomalies,
304
+ reasoning_steps=reasoning_steps,
305
+ source_trust=source_trust,
306
+ )
307
 
308
  source_trust_tensor = torch.tensor(
309
  [source_trust], dtype=torch.float32, device=self.device
 
317
  anomaly_confidence=anomaly_conf_tensor,
318
  reasoning_ids=reasoning_ids_tensor,
319
  reasoning_confidence=reasoning_conf_tensor,
320
+ composition_ids=composition_ids_tensor,
321
+ composition_confidence=composition_conf_tensor,
322
  source_trust=source_trust_tensor,
323
  )
324
 
325
+ # --- Step 3: ThinkingToggle assessment ---
326
+ thinking_mode_str = ""
327
+ complexity_score = 0.0
328
+ assessment = None
329
+
330
+ if self._has_thinking_toggle:
331
+ assessment = self._assess_complexity(
332
+ graph_cond, force_thinking_mode=force_thinking_mode
333
+ )
334
+ if assessment is not None:
335
+ thinking_mode_str = assessment.mode.value
336
+ complexity_score = (
337
+ assessment.complexity_score.mean().item()
338
+ if assessment.complexity_score.numel() > 0
339
+ else 0.0
340
+ )
341
+
342
+ # Adaptive step count based on thinking assessment
343
+ if method == "anchored":
344
+ depth_mult = assessment.depth_multiplier.mean().item()
345
+ n_steps = max(2, min(5, int(3 * depth_mult)))
346
+ elif method in ("ddpm", "ddim"):
347
+ depth_mult = assessment.depth_multiplier.mean().item()
348
+ n_steps = max(
349
+ 10,
350
+ int(self.inference_config.n_steps * depth_mult),
351
+ )
352
+
353
+ logger.debug(
354
+ "ThinkingToggle: mode=%s, complexity=%.3f, "
355
+ "depth_mult=%.2f, n_steps=%d",
356
+ thinking_mode_str,
357
+ complexity_score,
358
+ assessment.depth_multiplier.mean().item(),
359
+ n_steps,
360
+ )
361
+
362
+ # --- Step 4: MCTS reasoning (for complex inputs) ---
363
+ mcts_used = False
364
+ mcts_info: Dict[str, Any] = {}
365
+
366
+ should_use_mcts = self._should_use_mcts(
367
+ use_mcts=use_mcts,
368
+ assessment=assessment,
369
+ method=method,
370
+ )
371
+
372
+ if should_use_mcts:
373
+ mcts_result = self._run_mcts_reasoning(graph_cond)
374
+ if mcts_result is not None:
375
+ mcts_used = True
376
+ mcts_info = mcts_result
377
+
378
+ # --- Step 5: Generate via diffusion denoising ---
379
  shape = (
380
  1,
381
  self.config.model.max_seq_len,
 
385
  denoised = self.model.sample(
386
  graph_cond=graph_cond,
387
  n_steps=n_steps,
388
+ method=method,
389
  shape=shape,
390
  device=self.device,
391
+ temperature=temperature,
392
  )
393
 
394
+ # --- Step 6: Convert to tokens ---
395
+ # Extract graph context for anchored decoder
396
+ graph_values = graph_cond.get("values")
397
+ graph_context = None
398
+ if graph_values is not None:
399
+ graph_context = graph_values.mean(dim=1)
400
+
401
  token_ids = self.model.embeddings_to_tokens(
402
+ denoised,
403
+ temperature=temperature,
404
  top_k=self.inference_config.top_k,
405
+ graph_context=graph_context,
406
  )
407
 
408
+ # --- Step 7: Detokenize ---
409
  token_list = token_ids[0].cpu().tolist()
410
  narrative = self.tokenizer.decode(token_list, skip_special=True)
411
 
 
422
  if confidence_map:
423
  avg_confidence = sum(confidence_map.values()) / len(confidence_map)
424
 
425
+ # Collect memory stats
426
+ mem_stats = self.model.memory_stats() if self._has_dual_memory else {}
427
+
428
+ # Consolidate memory for future generations
429
+ if self._has_dual_memory:
430
+ self.model.memory_consolidate()
431
+
432
  return GenerationResult(
433
  narrative=narrative,
434
  token_ids=token_list,
 
438
  evidence_used=evidence_nodes or [],
439
  confidence=avg_confidence,
440
  language=language,
441
+ sampling_method=method,
442
+ thinking_mode=thinking_mode_str,
443
+ complexity_score=complexity_score,
444
+ mcts_used=mcts_used,
445
+ memory_stats=mem_stats,
446
  )
447
 
448
+ # ================================================================
449
+ # Internal helpers
450
+ # ================================================================
451
+
452
def _tokenize_graph_conditioning(
    self,
    evidence_nodes: Optional[list[str]] = None,
    compositions: Optional[list[str]] = None,
    confidence_map: Optional[dict[str, float]] = None,
    anomalies: Optional[list[str]] = None,
    reasoning_steps: Optional[list[str]] = None,
    source_trust: float = 1.0,
) -> tuple:
    """Tokenize all graph conditioning data into tensors.

    Each non-empty group is truncated to its configured maximum
    (``self.config.graph_encoder.max_*``), every item is encoded and
    padded by the tokenizer to 32 token ids, and missing slots are
    filled with all-zero rows. A group that is ``None`` or empty
    yields ``None`` for both its ids and its confidence tensor, and
    its ``max_*`` config attribute is not read.

    Args:
        evidence_nodes: Evidence node strings.
        compositions: Composition strings.
        confidence_map: Per-evidence-node confidence; nodes missing
            from the map default to 0.7.
        anomalies: Anomaly strings.
        reasoning_steps: Reasoning step strings.
        source_trust: Accepted for interface compatibility; not used
            by this method.

    Returns:
        Tuple of (evidence_ids, evidence_conf, anomaly_ids,
        anomaly_conf, reasoning_ids, reasoning_conf,
        composition_ids, composition_conf) tensors, each with a
        leading batch dimension of 1, or ``None`` per missing group.
    """
    graph_cfg = self.config.graph_encoder
    node_len = 32
    conf_lookup = confidence_map or {}

    def encode_rows(texts, max_items):
        # Tokenize up to max_items texts and zero-pad the row count.
        # Returns the (1, rows, node_len) id tensor and the number of
        # real (non-padding) rows.
        # NOTE(review): assumes tokenizer.pad_sequence returns exactly
        # node_len ids per row; otherwise torch.tensor would reject
        # the ragged list -- confirm against the tokenizer.
        rows = []
        for text in texts[:max_items]:
            ids = self.tokenizer.encode(text, add_special=False)
            rows.append(self.tokenizer.pad_sequence(ids, node_len))
        n_real = len(rows)
        rows.extend([0] * node_len for _ in range(max_items - n_real))
        ids_tensor = torch.tensor(
            [rows], dtype=torch.long, device=self.device
        )
        return ids_tensor, n_real

    def conf_row(values):
        # Wrap a flat list of confidences into a (1, len) float tensor.
        return torch.tensor(
            [values], dtype=torch.float32, device=self.device
        )

    evidence_ids = evidence_conf = None
    composition_ids = composition_conf = None
    anomaly_ids = anomaly_conf = None
    reasoning_ids = reasoning_conf = None

    # Evidence nodes: per-node confidence, padding slots get 0.0.
    if evidence_nodes:
        max_evidence = graph_cfg.max_evidence_nodes
        evidence_ids, n_real = encode_rows(evidence_nodes, max_evidence)
        confs = [
            conf_lookup.get(node, 0.7)
            for node in evidence_nodes[:max_evidence]
        ]
        confs += [0.0] * (max_evidence - n_real)
        evidence_conf = conf_row(confs)

    # Compositions: fixed 0.8 for real entries, 0.0 for padding.
    if compositions:
        max_compositions = graph_cfg.max_compositions
        composition_ids, n_real = encode_rows(compositions, max_compositions)
        confs = [0.8] * n_real + [0.0] * (max_compositions - n_real)
        composition_conf = conf_row(confs)

    # Anomalies: constant 0.6 across every slot.
    # NOTE(review): unlike evidence/compositions, padding slots also
    # receive 0.6 here; kept as-is to preserve model inputs -- confirm
    # whether padding should be 0.0.
    if anomalies:
        max_anomalies = graph_cfg.max_anomalies
        anomaly_ids, _ = encode_rows(anomalies, max_anomalies)
        anomaly_conf = torch.full(
            (1, max_anomalies),
            0.6, dtype=torch.float32, device=self.device,
        )

    # Reasoning steps: constant 0.7 across every slot (same caveat).
    if reasoning_steps:
        max_reasoning = graph_cfg.max_reasoning_steps
        reasoning_ids, _ = encode_rows(reasoning_steps, max_reasoning)
        reasoning_conf = torch.full(
            (1, max_reasoning),
            0.7, dtype=torch.float32, device=self.device,
        )

    return (
        evidence_ids,
        evidence_conf,
        anomaly_ids,
        anomaly_conf,
        reasoning_ids,
        reasoning_conf,
        composition_ids,
        composition_conf,
    )
574
+
575
def _assess_complexity(
    self,
    graph_cond: dict[str, torch.Tensor],
    force_thinking_mode: Optional[str] = None,
) -> Optional[Any]:
    """Use ThinkingToggle to assess the complexity of the input.

    Args:
        graph_cond: Graph conditioning dict from encoder; only the
            ``"values"`` entry is read.
        force_thinking_mode: Force 'thinking' or 'non_thinking'.
            ``None`` (or any unrecognised value, which is logged)
            lets the toggle decide on its own.

    Returns:
        Whatever ``self.model.thinking_toggle`` returns, or ``None``
        when the toggle is unavailable, the conditioning has no
        ``"values"`` entry, or the toggle raised.
    """
    if not self._has_thinking_toggle:
        return None

    # Cheap exit first so the project import below is only paid for
    # (and only required) when an assessment can actually run.
    graph_values = graph_cond.get("values")
    if graph_values is None:
        return None

    from diffusion_llm.model.thinking_toggle import ThinkingMode

    # A 2-D tensor gets a leading dimension added so the toggle
    # always receives a 3-D input.
    if graph_values.dim() == 2:
        graph_values = graph_values.unsqueeze(0)

    known_modes = {
        "thinking": ThinkingMode.THINKING,
        "non_thinking": ThinkingMode.NON_THINKING,
    }
    force_mode = None
    if force_thinking_mode:
        force_mode = known_modes.get(force_thinking_mode)
        if force_mode is None:
            # Previously a typo here silently fell back to auto mode.
            logger.warning(
                "Unknown force_thinking_mode %r; expected 'thinking' or "
                "'non_thinking'. Falling back to automatic assessment.",
                force_thinking_mode,
            )

    # Best-effort by design: a failing toggle must not abort generation.
    try:
        return self.model.thinking_toggle(
            graph_values, force_mode=force_mode
        )
    except Exception as e:
        logger.warning("ThinkingToggle assessment failed: %s", e)
        return None
618
+
619
def _should_use_mcts(
    self,
    use_mcts: Optional[bool],
    assessment: Optional[Any],
    method: str,
) -> bool:
    """Decide whether the MCTS reasoner should run for this request.

    Resolution order:
        1. No MCTS module on the model -> never.
        2. An explicit ``use_mcts`` (True/False) is returned as given.
        3. Auto mode (``use_mcts is None``) needs an assessment whose
           mode is THINKING and whose dominant task is REASONING or
           ANOMALY_RESOLUTION.

    Args:
        use_mcts: Caller override, or None for auto-decide.
        assessment: ThinkingToggle assessment, or None.
        method: Sampling method name; not consulted by this method.

    Returns:
        True when MCTS should be used.
    """
    if not self._has_mcts:
        return False

    explicit_choice = use_mcts is not None
    if explicit_choice:
        return use_mcts

    # Auto mode has nothing to go on without an assessment.
    if assessment is None:
        return False

    from diffusion_llm.model.thinking_toggle import (
        ThinkingMode,
        TaskType,
    )

    if assessment.mode != ThinkingMode.THINKING:
        return False

    # Only reasoning-heavy task types warrant the search.
    reasoning_heavy = (
        TaskType.REASONING,
        TaskType.ANOMALY_RESOLUTION,
    )
    return assessment.dominant_task in reasoning_heavy
660
+
661
def _run_mcts_reasoning(
    self,
    graph_cond: dict[str, torch.Tensor],
) -> Optional[Dict[str, Any]]:
    """Run the model's MCTS reasoner over the graph conditioning.

    Args:
        graph_cond: Graph conditioning dict from encoder; only the
            ``"values"`` entry is read.

    Returns:
        Summary dict with keys ``action_probs_mean``,
        ``total_simulations``, ``root_value`` and ``entropy``, or
        None when there are no values or the reasoner raised.
    """
    values = graph_cond.get("values")
    if values is None:
        return None

    # A 2-D tensor gets a leading dimension added before the call.
    mcts_input = values.unsqueeze(0) if values.dim() == 2 else values

    # Fields copied from the reasoner's info mapping, with the fallback
    # used when the reasoner does not report them.
    info_defaults = (
        ("total_simulations", 0),
        ("root_value", 0.0),
        ("entropy", 0.0),
    )

    # Best-effort: a failing reasoner is logged, not propagated.
    try:
        action_probs, info = self.model.mcts_reasoner(mcts_input)
        summary = {"action_probs_mean": action_probs.mean().item()}
        for key, fallback in info_defaults:
            summary[key] = info.get(key, fallback)
        return summary
    except Exception as e:
        logger.warning("MCTS reasoning failed: %s", e)
        return None
692
+
693
+ # ================================================================
694
+ # Batch generation
695
+ # ================================================================
696
+
697
  def generate_batch(
698
  self,
699
  triggers: list[str],
 
724
  )
725
  results.append(result)
726
  return results
727
+
728
+ # ================================================================
729
+ # Memory management
730
+ # ================================================================
731
+
732
def clear_memory(self) -> None:
    """Reset the model's dual memory, if the model has one.

    Intended for use between independent generation sessions. Does
    nothing when the model has no dual memory.
    """
    if not self._has_dual_memory:
        return
    self.model.memory_clear()
    logger.info("Dual memory cleared.")
740
+
741
def get_memory_stats(self) -> Dict[str, object]:
    """Report the model's current memory statistics.

    Returns:
        The result of ``self.model.memory_stats()``, or an empty
        dict when the model has no dual memory.
    """
    return self.model.memory_stats() if self._has_dual_memory else {}
750
+
751
+ # ================================================================
752
+ # Convenience methods
753
+ # ================================================================
754
+
755
def generate_fast(
    self,
    trigger: str = "",
    **kwargs,
) -> GenerationResult:
    """Generate with fastest settings (non-thinking, anchored, minimal steps).

    Convenience wrapper around ``generate()`` that presets
    ``method="anchored"``, ``force_thinking_mode="non_thinking"``,
    ``use_mcts=False`` and ``n_steps=2``. Any of these may be
    overridden through ``kwargs``; previously doing so raised
    ``TypeError`` because the keyword was passed twice.

    NOTE(review): generate() appears to recompute n_steps from the
    ThinkingToggle depth multiplier when an assessment is available,
    so the value passed here may not be the final step count --
    confirm.

    Args:
        trigger: The trigger question or topic.
        **kwargs: Additional arguments passed to generate(); these
            take precedence over the fast presets.

    Returns:
        GenerationResult with the narrative.
    """
    options = {
        "method": "anchored",
        "force_thinking_mode": "non_thinking",
        "use_mcts": False,
        "n_steps": 2,
    }
    # Caller-supplied values win over the presets.
    options.update(kwargs)
    return self.generate(trigger=trigger, **options)
779
+
780
def generate_deep(
    self,
    trigger: str = "",
    **kwargs,
) -> GenerationResult:
    """Generate with deepest reasoning (thinking, MCTS, more steps).

    Convenience wrapper around ``generate()`` that presets
    ``force_thinking_mode="thinking"`` and ``use_mcts=True``, picks
    ``"anchored"`` when the model has an anchored decoder (else
    ``"ddim"``), and defaults ``n_steps`` to 5 for anchored decoding
    or 100 otherwise. Any preset may be overridden through
    ``kwargs``; previously doing so raised ``TypeError`` because the
    keyword was passed twice.

    Args:
        trigger: The trigger question or topic.
        **kwargs: Additional arguments passed to generate(); these
            take precedence over the deep presets.

    Returns:
        GenerationResult with the narrative.
    """
    options = dict(kwargs)

    # Resolve the method first: the default step budget depends on it.
    if options.get("method") is None:
        options["method"] = (
            "anchored" if self._has_anchored_decoder else "ddim"
        )
    method = options["method"]

    options.setdefault("force_thinking_mode", "thinking")
    options.setdefault("use_mcts", True)
    options.setdefault("n_steps", 5 if method == "anchored" else 100)
    return self.generate(trigger=trigger, **options)