"""
AAM Diffusion LLM — Inference Generator (v2.0)
Generates natural language narratives from graph conditioning
using the trained diffusion model.
v2.0 Upgrades:
- ThinkingToggle for adaptive inference (thinking vs non-thinking)
- Anchored decoding method (2-3 steps instead of 50)
- Flow matching method (velocity-based 2-3 step sampling)
- MCTS integration for complex reasoning tasks
- DualMemorySystem for long narrative generation
- Full backward compatibility with v1.0 generation
The generation process (v2.0 Anchored):
1. Encode graph conditioning (evidence, anomalies, reasoning)
2. [Optional] ThinkingToggle assesses complexity
3. [Optional] MCTS explores narrative arrangements for complex inputs
4. Generate via anchored decoding (2-3 refinement steps)
5. Convert denoised embeddings to token IDs
6. Detokenize to natural language text
The generation process (Legacy DDPM/DDIM):
1. Encode graph conditioning
2. Start from pure noise in the latent space
3. Iteratively denoise for N steps
4. Convert denoised embeddings to token IDs
5. Detokenize to natural language text
Analogy: like Jin Soun finally "speaking": blurry thoughts (noise)
become clear words (the denoised narrative). In v2.0, Jin Soun can
now choose to speak quickly about simple matters (non-thinking) or
to think deeply about hard problems (thinking + MCTS).
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import torch
from diffusion_llm.config.model_config import AamDiffusionConfig, InferenceConfig
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
logger = logging.getLogger(__name__)
@dataclass
class GenerationResult:
"""Result from a generation call.
Contains the generated narrative plus metadata about
how it was generated, for traceability.
"""
narrative: str
"""Generated narrative text."""
token_ids: list[int] = field(default_factory=list)
"""Generated token IDs."""
n_diffusion_steps: int = 0
"""Number of denoising steps used."""
generation_time_s: float = 0.0
"""Wall-clock generation time."""
model_name: str = ""
"""Name of the model used."""
evidence_used: list[str] = field(default_factory=list)
"""Evidence nodes that were provided as conditioning."""
confidence: float = 0.0
"""Overall confidence of the generation."""
language: str = "id"
"""Output language."""
# v2.0 metadata
sampling_method: str = "ddim"
"""Sampling method used ('anchored', 'flow_matching', 'ddpm', 'ddim')."""
thinking_mode: str = ""
"""ThinkingToggle mode: 'thinking', 'non_thinking', or '' if disabled."""
complexity_score: float = 0.0
"""Complexity score from ThinkingToggle (0.0 if disabled)."""
mcts_used: bool = False
"""Whether MCTS reasoning was used."""
memory_stats: Dict[str, object] = field(default_factory=dict)
"""DualMemory statistics at generation time."""
def to_dict(self) -> dict:
"""Serialize to dictionary."""
result = {
"narrative": self.narrative,
"n_diffusion_steps": self.n_diffusion_steps,
"generation_time_s": round(self.generation_time_s, 3),
"model_name": self.model_name,
"evidence_used": self.evidence_used,
"confidence": round(self.confidence, 3),
"language": self.language,
"sampling_method": self.sampling_method,
}
if self.thinking_mode:
result["thinking_mode"] = self.thinking_mode
result["complexity_score"] = round(self.complexity_score, 3)
if self.mcts_used:
result["mcts_used"] = True
if self.memory_stats:
result["memory_stats"] = self.memory_stats
return result
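# Illustrative note on GenerationResult.to_dict(): the v2.0 fields are
# serialized conditionally. With ThinkingToggle active, the payload looks
# like (values hypothetical):
#     {"narrative": "...", "sampling_method": "anchored",
#      "thinking_mode": "thinking", "complexity_score": 0.812, ...}
# When ThinkingToggle is disabled (thinking_mode == ""), the
# "thinking_mode"/"complexity_score" keys are omitted entirely, as are
# "mcts_used" and "memory_stats" when unset.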
class AamGenerator:
"""Generate narratives from graph conditioning using the trained model (v2.0).
This is the main inference interface. It takes graph-structured
data (from the RSVS Knowledge Graph) and produces natural
language narratives through the diffusion denoising process.
v2.0 features:
- Adaptive compute via ThinkingToggle
- Fast anchored decoding (2-3 steps)
- Flow matching decoding
- MCTS for complex reasoning
- Dual memory for long narratives
Usage:
# Load model and tokenizer
config = AamDiffusionConfig.from_json("config.json")
model = AamDiffusionModel.load("best.pt")
tokenizer = AamTokenizer.load("tokenizer.json")
# Create generator
generator = AamGenerator(model, tokenizer, config)
# Generate narrative (v2.0 anchored decoding)
result = generator.generate(
trigger="Siapa yang mencuri Snow Plum Pill?",
evidence_nodes=["hefei", "diancang", "ju_jangmok"],
anomalies=["no external pill consumption"],
reasoning_steps=["Diancang pair was in Hefei before theft"],
method="anchored",
)
print(result.narrative)
# Generate narrative (legacy DDIM)
result = generator.generate(
trigger="Summary of events",
evidence_nodes=["event_a", "event_b"],
method="ddim",
)
print(result.narrative)
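        # Convenience wrappers (generate_fast / generate_deep, below)
        fast = generator.generate_fast(trigger="Quick recap")
        deep = generator.generate_deep(trigger="Complex reasoning question")
        print(fast.narrative)
        print(deep.narrative)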
Args:
model: Trained AamDiffusionModel.
tokenizer: Trained AamTokenizer.
config: AamDiffusionConfig with inference settings.
"""
def __init__(
self,
model: AamDiffusionModel,
tokenizer: AamTokenizer,
config: AamDiffusionConfig,
):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.inference_config = config.inference
# Device
self.device = next(model.parameters()).device
# Set model to eval mode
self.model.eval()
# Feature detection
self._has_anchored_decoder = hasattr(model, "output_head")
self._has_thinking_toggle = hasattr(model, "thinking_toggle")
self._has_flow_matching = hasattr(model, "flow_matching_decoder")
self._has_mcts = hasattr(model, "mcts_reasoner")
self._has_dual_memory = hasattr(model, "dual_memory")
self._has_evoformer = hasattr(model, "evoformer")
logger.info(
"AamGenerator v2.0 initialized. Features: anchored=%s, thinking=%s, "
"flow=%s, mcts=%s, memory=%s, evoformer=%s",
self._has_anchored_decoder,
self._has_thinking_toggle,
self._has_flow_matching,
self._has_mcts,
self._has_dual_memory,
self._has_evoformer,
)
@torch.no_grad()
def generate(
self,
trigger: str = "",
evidence_nodes: Optional[list[str]] = None,
compositions: Optional[list[str]] = None,
confidence_map: Optional[dict[str, float]] = None,
anomalies: Optional[list[str]] = None,
reasoning_steps: Optional[list[str]] = None,
source_trust: float = 1.0,
n_steps: Optional[int] = None,
temperature: Optional[float] = None,
language: Optional[str] = None,
max_sentences: Optional[int] = None,
method: Optional[str] = None,
use_mcts: Optional[bool] = None,
force_thinking_mode: Optional[str] = None,
) -> GenerationResult:
"""Generate a narrative from graph conditioning.
This is the main generation method. It:
1. Tokenizes the graph conditioning data
2. Encodes it through the graph encoder
3. [v2.0] Optionally assesses thinking complexity
4. [v2.0] Optionally runs MCTS for complex reasoning
5. Generates via the selected sampling method
6. Converts the result to text
Args:
            trigger: The trigger question or topic. Note: this value is
                currently not encoded into the graph conditioning.
evidence_nodes: Evidence node descriptions.
compositions: Composition descriptions.
confidence_map: Node confidence scores.
anomalies: Anomaly descriptions.
reasoning_steps: Reasoning step descriptions.
source_trust: Source trust score.
n_steps: Override number of denoising steps.
temperature: Override sampling temperature.
language: Override output language.
max_sentences: Maximum sentences in output.
method: Sampling method — 'anchored', 'flow_matching',
'ddpm', 'ddim', or None (uses config default).
use_mcts: Override whether to use MCTS. None = auto-decide
based on ThinkingToggle assessment.
force_thinking_mode: Force thinking mode ('thinking' or
'non_thinking'). None = auto-decide.
Returns:
GenerationResult with the narrative and metadata.
"""
start_time = time.time()
        # Use config defaults only when not overridden; explicit None
        # checks so falsy overrides (e.g. temperature=0.0) are respected.
        if n_steps is None:
            n_steps = self.inference_config.n_steps
        if temperature is None:
            temperature = self.inference_config.temperature
        if language is None:
            language = self.inference_config.language
        if max_sentences is None:
            max_sentences = self.inference_config.max_output_sentences
# Determine sampling method
if method is None:
# Default to anchored if available, else use config
if self._has_anchored_decoder:
method = "anchored"
else:
method = self.config.diffusion.sampling_method
# Validate method availability
if method == "anchored" and not self._has_anchored_decoder:
logger.warning(
"Anchored decoding requested but ContinuousOutputHead not "
"available. Falling back to '%s'.",
self.config.diffusion.sampling_method,
)
method = self.config.diffusion.sampling_method
if method == "flow_matching" and not self._has_flow_matching:
logger.warning(
"Flow matching requested but FlowMatchingDecoder not "
"available. Falling back to '%s'.",
self.config.diffusion.sampling_method,
)
method = self.config.diffusion.sampling_method
# --- Step 1: Tokenize graph conditioning ---
(
evidence_ids_tensor,
evidence_conf_tensor,
anomaly_ids_tensor,
anomaly_conf_tensor,
reasoning_ids_tensor,
reasoning_conf_tensor,
composition_ids_tensor,
composition_conf_tensor,
) = self._tokenize_graph_conditioning(
evidence_nodes=evidence_nodes,
compositions=compositions,
confidence_map=confidence_map,
anomalies=anomalies,
reasoning_steps=reasoning_steps,
source_trust=source_trust,
)
source_trust_tensor = torch.tensor(
[source_trust], dtype=torch.float32, device=self.device
)
# --- Step 2: Encode graph conditioning ---
graph_cond = self.model.graph_encoder(
evidence_ids=evidence_ids_tensor,
evidence_confidence=evidence_conf_tensor,
anomaly_ids=anomaly_ids_tensor,
anomaly_confidence=anomaly_conf_tensor,
reasoning_ids=reasoning_ids_tensor,
reasoning_confidence=reasoning_conf_tensor,
composition_ids=composition_ids_tensor,
composition_confidence=composition_conf_tensor,
source_trust=source_trust_tensor,
)
# --- Step 3: ThinkingToggle assessment ---
thinking_mode_str = ""
complexity_score = 0.0
assessment = None
if self._has_thinking_toggle:
assessment = self._assess_complexity(
graph_cond, force_thinking_mode=force_thinking_mode
)
if assessment is not None:
thinking_mode_str = assessment.mode.value
complexity_score = (
assessment.complexity_score.mean().item()
if assessment.complexity_score.numel() > 0
else 0.0
)
# Adaptive step count based on thinking assessment
if method == "anchored":
depth_mult = assessment.depth_multiplier.mean().item()
n_steps = max(2, min(5, int(3 * depth_mult)))
elif method in ("ddpm", "ddim"):
depth_mult = assessment.depth_multiplier.mean().item()
n_steps = max(
10,
int(self.inference_config.n_steps * depth_mult),
)
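                # e.g. depth_mult=1.5 gives anchored int(3 * 1.5) = 4 steps
                # (clamped to [2, 5]) and 1.5x the configured DDPM/DDIM
                # step count (floored at 10).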
logger.debug(
"ThinkingToggle: mode=%s, complexity=%.3f, "
"depth_mult=%.2f, n_steps=%d",
thinking_mode_str,
complexity_score,
assessment.depth_multiplier.mean().item(),
n_steps,
)
# --- Step 4: MCTS reasoning (for complex inputs) ---
mcts_used = False
mcts_info: Dict[str, Any] = {}
should_use_mcts = self._should_use_mcts(
use_mcts=use_mcts,
assessment=assessment,
method=method,
)
if should_use_mcts:
mcts_result = self._run_mcts_reasoning(graph_cond)
if mcts_result is not None:
mcts_used = True
mcts_info = mcts_result
# --- Step 5: Generate via diffusion denoising ---
shape = (
1,
self.config.model.max_seq_len,
self.config.model.d_model,
)
denoised = self.model.sample(
graph_cond=graph_cond,
n_steps=n_steps,
method=method,
shape=shape,
device=self.device,
temperature=temperature,
)
# --- Step 6: Convert to tokens ---
# Extract graph context for anchored decoder
graph_values = graph_cond.get("values")
graph_context = None
if graph_values is not None:
graph_context = graph_values.mean(dim=1)
token_ids = self.model.embeddings_to_tokens(
denoised,
temperature=temperature,
top_k=self.inference_config.top_k,
graph_context=graph_context,
)
# --- Step 7: Detokenize ---
token_list = token_ids[0].cpu().tolist()
narrative = self.tokenizer.decode(token_list, skip_special=True)
# Truncate to max sentences
if max_sentences:
sentences = self.tokenizer._split_sentences(narrative)
if len(sentences) > max_sentences:
narrative = ". ".join(sentences[:max_sentences]) + "."
generation_time = time.time() - start_time
# Compute average confidence
avg_confidence = source_trust
if confidence_map:
avg_confidence = sum(confidence_map.values()) / len(confidence_map)
# Collect memory stats
mem_stats = self.model.memory_stats() if self._has_dual_memory else {}
# Consolidate memory for future generations
if self._has_dual_memory:
self.model.memory_consolidate()
return GenerationResult(
narrative=narrative,
token_ids=token_list,
n_diffusion_steps=n_steps,
generation_time_s=generation_time,
model_name=self.config.model_name,
evidence_used=evidence_nodes or [],
confidence=avg_confidence,
language=language,
sampling_method=method,
thinking_mode=thinking_mode_str,
complexity_score=complexity_score,
mcts_used=mcts_used,
memory_stats=mem_stats,
)
# ================================================================
# Internal helpers
# ================================================================
def _tokenize_graph_conditioning(
self,
evidence_nodes: Optional[list[str]] = None,
compositions: Optional[list[str]] = None,
confidence_map: Optional[dict[str, float]] = None,
anomalies: Optional[list[str]] = None,
reasoning_steps: Optional[list[str]] = None,
source_trust: float = 1.0,
) -> tuple:
"""Tokenize all graph conditioning data into tensors.
Returns:
Tuple of (evidence_ids, evidence_conf, anomaly_ids,
anomaly_conf, reasoning_ids, reasoning_conf,
composition_ids, composition_conf) tensors.
"""
evidence_ids_tensor = None
evidence_conf_tensor = None
anomaly_ids_tensor = None
anomaly_conf_tensor = None
reasoning_ids_tensor = None
reasoning_conf_tensor = None
composition_ids_tensor = None
composition_conf_tensor = None
max_evidence = self.config.graph_encoder.max_evidence_nodes
max_anomalies = self.config.graph_encoder.max_anomalies
max_reasoning = self.config.graph_encoder.max_reasoning_steps
max_compositions = self.config.graph_encoder.max_compositions
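        # Fixed per-description token budget: each node/anomaly/step text
        # is encoded and then padded to node_len tokens, so every
        # conditioning tensor built below has shape (1, max_<kind>,
        # node_len), with all-zero rows for unused slots.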
node_len = 32
# Evidence nodes
if evidence_nodes:
evidence_ids_list = []
evidence_conf_list = []
for node in evidence_nodes[:max_evidence]:
ids = self.tokenizer.encode(node, add_special=False)
ids = self.tokenizer.pad_sequence(ids, node_len)
evidence_ids_list.append(ids)
conf = (confidence_map or {}).get(node, 0.7)
evidence_conf_list.append(conf)
while len(evidence_ids_list) < max_evidence:
evidence_ids_list.append([0] * node_len)
evidence_conf_list.append(0.0)
evidence_ids_tensor = torch.tensor(
[evidence_ids_list], dtype=torch.long, device=self.device
)
evidence_conf_tensor = torch.tensor(
[evidence_conf_list], dtype=torch.float32, device=self.device
)
# Compositions
if compositions:
composition_ids_list = []
composition_conf_list = []
for comp in compositions[:max_compositions]:
ids = self.tokenizer.encode(comp, add_special=False)
ids = self.tokenizer.pad_sequence(ids, node_len)
composition_ids_list.append(ids)
composition_conf_list.append(0.8)
while len(composition_ids_list) < max_compositions:
composition_ids_list.append([0] * node_len)
composition_conf_list.append(0.0)
composition_ids_tensor = torch.tensor(
[composition_ids_list], dtype=torch.long, device=self.device
)
composition_conf_tensor = torch.tensor(
[composition_conf_list], dtype=torch.float32, device=self.device
)
# Anomalies
if anomalies:
anomaly_ids_list = []
for anom in anomalies[:max_anomalies]:
ids = self.tokenizer.encode(anom, add_special=False)
ids = self.tokenizer.pad_sequence(ids, node_len)
anomaly_ids_list.append(ids)
while len(anomaly_ids_list) < max_anomalies:
anomaly_ids_list.append([0] * node_len)
anomaly_ids_tensor = torch.tensor(
[anomaly_ids_list], dtype=torch.long, device=self.device
)
anomaly_conf_tensor = torch.full(
(1, max_anomalies),
0.6, dtype=torch.float32, device=self.device,
)
# Reasoning steps
if reasoning_steps:
reasoning_ids_list = []
for step in reasoning_steps[:max_reasoning]:
ids = self.tokenizer.encode(step, add_special=False)
ids = self.tokenizer.pad_sequence(ids, node_len)
reasoning_ids_list.append(ids)
while len(reasoning_ids_list) < max_reasoning:
reasoning_ids_list.append([0] * node_len)
reasoning_ids_tensor = torch.tensor(
[reasoning_ids_list], dtype=torch.long, device=self.device
)
reasoning_conf_tensor = torch.full(
(1, max_reasoning),
0.7, dtype=torch.float32, device=self.device,
)
return (
evidence_ids_tensor,
evidence_conf_tensor,
anomaly_ids_tensor,
anomaly_conf_tensor,
reasoning_ids_tensor,
reasoning_conf_tensor,
composition_ids_tensor,
composition_conf_tensor,
)
def _assess_complexity(
self,
graph_cond: dict[str, torch.Tensor],
force_thinking_mode: Optional[str] = None,
) -> Optional[Any]:
"""Use ThinkingToggle to assess the complexity of the input.
Args:
graph_cond: Graph conditioning dict from encoder.
force_thinking_mode: Force 'thinking' or 'non_thinking'.
Returns:
ThinkingAssessment or None if not available.
"""
if not self._has_thinking_toggle:
return None
from diffusion_llm.model.thinking_toggle import ThinkingMode
# Build a hidden-state-like tensor from graph conditioning
# for the ThinkingToggle to assess
graph_values = graph_cond.get("values")
if graph_values is None:
return None
# Reshape to (batch, seq, d_model) if needed
if graph_values.dim() == 2:
graph_values = graph_values.unsqueeze(0)
force_mode = None
if force_thinking_mode == "thinking":
force_mode = ThinkingMode.THINKING
elif force_thinking_mode == "non_thinking":
force_mode = ThinkingMode.NON_THINKING
try:
assessment = self.model.thinking_toggle(
graph_values, force_mode=force_mode
)
return assessment
except Exception as e:
logger.warning("ThinkingToggle assessment failed: %s", e)
return None
def _should_use_mcts(
self,
use_mcts: Optional[bool],
assessment: Optional[Any],
method: str,
) -> bool:
"""Determine whether MCTS should be used.
Logic:
- If use_mcts is explicitly True/False, use that.
- If use_mcts is None (auto), use MCTS when:
- ThinkingToggle is in THINKING mode, AND
- The task type is REASONING or ANOMALY_RESOLUTION, AND
- MCTS module is available
"""
if not self._has_mcts:
return False
if use_mcts is not None:
return use_mcts
# Auto-decide based on ThinkingToggle
if assessment is None:
return False
from diffusion_llm.model.thinking_toggle import (
ThinkingMode,
TaskType,
)
if assessment.mode != ThinkingMode.THINKING:
return False
# Only use MCTS for reasoning-heavy task types
if assessment.dominant_task in (
TaskType.REASONING,
TaskType.ANOMALY_RESOLUTION,
):
return True
return False
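    # Illustrative auto-decision outcomes (with use_mcts=None and MCTS
    # available; assessment values hypothetical):
    #   mode=THINKING,     dominant_task=REASONING          -> True
    #   mode=THINKING,     dominant_task=ANOMALY_RESOLUTION -> True
    #   mode=THINKING,     any other dominant_task          -> False
    #   mode=NON_THINKING, any dominant_task                -> False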
def _run_mcts_reasoning(
self,
graph_cond: dict[str, torch.Tensor],
) -> Optional[Dict[str, Any]]:
"""Run MCTS reasoning on graph conditioning.
Args:
graph_cond: Graph conditioning dict from encoder.
Returns:
Dict with MCTS info, or None if MCTS failed.
"""
graph_values = graph_cond.get("values")
if graph_values is None:
return None
# Reshape for MCTS input
if graph_values.dim() == 2:
graph_values = graph_values.unsqueeze(0)
try:
action_probs, info = self.model.mcts_reasoner(graph_values)
return {
"action_probs_mean": action_probs.mean().item(),
"total_simulations": info.get("total_simulations", 0),
"root_value": info.get("root_value", 0.0),
"entropy": info.get("entropy", 0.0),
}
except Exception as e:
logger.warning("MCTS reasoning failed: %s", e)
return None
# ================================================================
# Batch generation
# ================================================================
def generate_batch(
self,
triggers: list[str],
evidence_nodes_list: Optional[list[list[str]]] = None,
anomalies_list: Optional[list[list[str]]] = None,
**kwargs,
) -> list[GenerationResult]:
"""Generate narratives for multiple triggers.
Args:
triggers: List of trigger questions.
evidence_nodes_list: List of evidence node lists.
anomalies_list: List of anomaly lists.
**kwargs: Additional arguments passed to generate().
Returns:
List of GenerationResult objects.
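        Note:
            Triggers are processed sequentially, one generate() call per
            trigger, rather than as one batched forward pass. When given,
            evidence_nodes_list and anomalies_list must be the same
            length as triggers.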
"""
results = []
for i, trigger in enumerate(triggers):
evidence = evidence_nodes_list[i] if evidence_nodes_list else None
anomalies = anomalies_list[i] if anomalies_list else None
result = self.generate(
trigger=trigger,
evidence_nodes=evidence,
anomalies=anomalies,
**kwargs,
)
results.append(result)
return results
# ================================================================
# Memory management
# ================================================================
def clear_memory(self) -> None:
"""Clear the model's dual memory system.
Useful between independent generation sessions.
"""
if self._has_dual_memory:
self.model.memory_clear()
logger.info("Dual memory cleared.")
def get_memory_stats(self) -> Dict[str, object]:
"""Get current memory statistics.
Returns:
Dict with memory stats, or empty dict if memory disabled.
"""
if self._has_dual_memory:
return self.model.memory_stats()
return {}
# ================================================================
# Convenience methods
# ================================================================
def generate_fast(
self,
trigger: str = "",
**kwargs,
) -> GenerationResult:
"""Generate with fastest settings (non-thinking, anchored, minimal steps).
Convenience wrapper for quick generation.
Args:
trigger: The trigger question or topic.
**kwargs: Additional arguments passed to generate().
Returns:
GenerationResult with the narrative.
"""
return self.generate(
trigger=trigger,
method="anchored",
force_thinking_mode="non_thinking",
use_mcts=False,
n_steps=2,
**kwargs,
)
def generate_deep(
self,
trigger: str = "",
**kwargs,
) -> GenerationResult:
"""Generate with deepest reasoning (thinking, MCTS, more steps).
Convenience wrapper for complex reasoning tasks.
Args:
trigger: The trigger question or topic.
**kwargs: Additional arguments passed to generate().
Returns:
GenerationResult with the narrative.
"""
method = "anchored" if self._has_anchored_decoder else "ddim"
return self.generate(
trigger=trigger,
method=method,
force_thinking_mode="thinking",
use_mcts=True,
n_steps=5 if method == "anchored" else 100,
**kwargs,
)
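if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; the artifact paths
    # below are hypothetical placeholders for real trained files). Mirrors
    # the usage shown in the AamGenerator class docstring.
    logging.basicConfig(level=logging.INFO)
    config = AamDiffusionConfig.from_json("config.json")
    model = AamDiffusionModel.load("best.pt")
    tokenizer = AamTokenizer.load("tokenizer.json")
    generator = AamGenerator(model, tokenizer, config)
    result = generator.generate(
        trigger="Siapa yang mencuri Snow Plum Pill?",
        evidence_nodes=["hefei", "diancang", "ju_jangmok"],
        anomalies=["no external pill consumption"],
        method="anchored",
    )
    print(result.to_dict())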