# aam-diffusion-v1 / diffusion_llm/model/thinking_toggle.py
# Uploaded by Wolfvin via huggingface_hub (commit 77d27ba, verified).
"""AAM Diffusion LLM — Thinking Toggle
Detects whether input needs deep reasoning (thinking) or quick response
(non-thinking). AAM-specific: simple factual query = 2 anchored steps,
complex reasoning = 5-10 steps + MCTS.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class ThinkingMode(Enum):
    """Response mode selected by ThinkingToggle (see module docstring)."""
    THINKING = "thinking"  # deep multi-step reasoning path
    NON_THINKING = "non_thinking"  # quick-response path
class TaskType(Enum):
    """Task categories scored by ThinkingToggle's classifier head.

    NOTE: member order is load-bearing — ThinkingToggle sizes its
    classifier output with len(TaskType) and decodes an argmax index
    back to a member via list(TaskType). Do not reorder members.
    """
    SEQUENTIAL = "sequential"
    REASONING = "reasoning"
    FACTUAL = "factual"
    CREATIVE = "creative"
    ANOMALY_RESOLUTION = "anomaly_resolution"
@dataclass
class ThinkingAssessment:
    """Result bundle produced by ThinkingToggle.forward."""
    # Selected mode: forced (kwarg or persistent buffer) or auto-detected.
    mode: ThinkingMode
    # Per-token complexity in [0, 1]; shape [batch, seq].
    complexity_score: torch.Tensor
    # Softmaxed task-type distribution; shape [batch, seq, len(TaskType)].
    task_type_probs: torch.Tensor
    # Task with the highest mean probability over batch and sequence.
    dominant_task: TaskType
    # Differentiable per-sample depth scale (soft blend of the two modes).
    depth_multiplier: torch.Tensor
    # 1.0 far from the threshold, 0.0 at it; clamped to [0, 1]; shape [batch].
    confidence: torch.Tensor
    # Raw fused score from the context integrator; shape [batch].
    thinking_score: Optional[torch.Tensor] = None
class ThinkingToggle(nn.Module):
    """Thinking/Non-Thinking Toggle for AAM Diffusion LLM.

    Scores per-token complexity and task type, fuses both into a single
    thinking score in [0, 1], and selects THINKING (deep reasoning) vs
    NON_THINKING (quick response) along with a continuous, differentiable
    depth multiplier.
    """

    # One classifier logit per TaskType member (member order is load-bearing).
    NUM_TASK_TYPES = len(TaskType)

    def __init__(self, d_model: int, threshold: float = 0.5) -> None:
        """Build the scorer, classifier, and integrator heads.

        Args:
            d_model: Hidden size of incoming token representations.
            threshold: Thinking-score cutoff in [0, 1] separating the modes.
        """
        super().__init__()
        self.d_model = d_model
        self.threshold = threshold
        # Per-token complexity score squashed into [0, 1].
        self.complexity_scorer = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, d_model // 4),
            nn.SiLU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid(),
        )
        # Per-token task-type logits (softmaxed in forward).
        self.task_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, self.NUM_TASK_TYPES),
        )
        # Fuses [mean complexity | mean task distribution] -> score in [0, 1].
        self.context_integrator = nn.Sequential(
            nn.Linear(1 + self.NUM_TASK_TYPES, d_model // 4),
            nn.SiLU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid(),
        )
        # Bounds for the depth multiplier / depth schedule.
        self.depth_min = 0.3
        self.depth_max = 2.0
        # Persistent force-mode code: -1 = auto, 0 = NON_THINKING, 1 = THINKING.
        # Stored as a registered buffer so it survives state_dict round-trips.
        self.register_buffer("_force_mode_code", torch.tensor(-1, dtype=torch.long), persistent=True)

    def forward(self, x: torch.Tensor, force_mode: Optional[ThinkingMode] = None) -> ThinkingAssessment:
        """Assess whether *x* needs deep thinking.

        Args:
            x: Token representations, shape [batch, seq, d_model].
            force_mode: v2.3.0 — overrides detection for this call only
                (thread-safe, no state mutation). Takes precedence over
                the persistent mode set via set_force_mode().

        Returns:
            ThinkingAssessment with mode, per-token scores, dominant task,
            differentiable depth multiplier, and confidence.

        Raises:
            ValueError: If x is not 3D or its last dim is not d_model.
        """
        if x.dim() != 3:
            raise ValueError(f"Input must be 3D [batch, seq, d_model], got {x.dim()}D")
        # Fix: fail fast with a clear message instead of a cryptic matmul
        # shape error inside the first Linear layer.
        if x.size(-1) != self.d_model:
            raise ValueError(f"Input last dim must be d_model={self.d_model}, got {x.size(-1)}")
        complexity = self.complexity_scorer(x).squeeze(-1)  # [batch, seq]
        task_logits = self.task_classifier(x)  # [batch, seq, NUM_TASK_TYPES]
        task_probs = F.softmax(task_logits, dim=-1)
        # Pool over the sequence, then fuse into one score per batch element.
        mean_complexity = complexity.mean(dim=-1, keepdim=True)  # [batch, 1]
        mean_task_probs = task_probs.mean(dim=1)  # [batch, NUM_TASK_TYPES]
        context_input = torch.cat([mean_complexity, mean_task_probs], dim=-1)
        thinking_score = self.context_integrator(context_input).squeeze(-1)  # [batch]
        # Mode resolution precedence: call kwarg > persistent buffer > auto.
        if force_mode is not None:
            mode = force_mode
        else:
            persistent_mode = self._get_force_mode()
            if persistent_mode is not None:
                mode = persistent_mode
            else:
                # Hard threshold on the batch-mean score for control flow;
                # intentionally non-differentiable (.item() syncs the device).
                avg_score_val = thinking_score.mean().item()
                mode = ThinkingMode.THINKING if avg_score_val > self.threshold else ThinkingMode.NON_THINKING
        overall_task_probs = task_probs.mean(dim=(0, 1))
        dominant_task_idx = overall_task_probs.argmax().item()
        dominant_task = list(TaskType)[dominant_task_idx]
        # v1.8.0: straight-through-style soft blend — `mode` above is a hard
        # decision, but depth_multiplier stays fully differentiable through
        # the sigmoid gate below.
        avg_thinking_score = thinking_score
        temperature = 5.0
        mode_weight = torch.sigmoid(temperature * (avg_thinking_score - self.threshold))
        thinking_depth = self.depth_min + (self.depth_max - self.depth_min) * avg_thinking_score
        non_thinking_depth = self.depth_min + 0.2 * avg_thinking_score
        depth_multiplier = mode_weight * thinking_depth + (1.0 - mode_weight) * non_thinking_depth
        # Confidence peaks when the score is far from the threshold on either side.
        confidence = 1.0 - (avg_thinking_score - self.threshold).abs() / max(self.threshold, 1.0 - self.threshold)
        confidence = confidence.clamp(0.0, 1.0)
        return ThinkingAssessment(
            mode=mode,
            complexity_score=complexity,
            task_type_probs=task_probs,
            dominant_task=dominant_task,
            depth_multiplier=depth_multiplier,
            confidence=confidence,
            thinking_score=thinking_score,
        )

    def _get_force_mode(self) -> Optional[ThinkingMode]:
        """Decode the persistent buffer back to a ThinkingMode (or None)."""
        code = int(self._force_mode_code.item())
        if code == -1:
            return None
        if code == 0:
            return ThinkingMode.NON_THINKING
        if code == 1:
            return ThinkingMode.THINKING
        # Corrupted value (e.g. a bad checkpoint) — reset to auto.
        self._force_mode_code.fill_(-1)
        return None

    def set_force_mode(self, mode: Optional[ThinkingMode]) -> None:
        """Force mode, bypassing detection. Set None for automatic detection.

        The mode is persisted via a registered buffer so it survives
        model.state_dict() / model.load_state_dict() round-trips.

        Raises:
            ValueError: If mode is neither None nor a ThinkingMode member.
        """
        if mode is None:
            self._force_mode_code.fill_(-1)
        elif mode == ThinkingMode.NON_THINKING:
            self._force_mode_code.fill_(0)
        elif mode == ThinkingMode.THINKING:
            self._force_mode_code.fill_(1)
        else:
            raise ValueError(f"Unknown ThinkingMode: {mode!r}")

    def set_threshold(self, threshold: float) -> None:
        """Update complexity threshold.

        Args:
            threshold: New threshold value (0.0 - 1.0)

        Raises:
            ValueError: If threshold is outside [0.0, 1.0].
        """
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"Threshold must be between 0.0 and 1.0, got {threshold}")
        self.threshold = threshold

    def get_thinking_mask(self, assessment: ThinkingAssessment, seq_len: int) -> torch.Tensor:
        """Create binary mask marking which tokens need thinking.

        Args:
            assessment: Assessment result from forward
            seq_len: Expected sequence length, validated against the
                assessment. (Fix: this argument was previously accepted
                and silently ignored.)

        Returns:
            Mask [batch, seq] — 1.0 for thinking, 0.0 for non-thinking

        Raises:
            ValueError: If seq_len does not match the assessment's seq dim.
        """
        mask = (assessment.complexity_score > self.threshold).float()
        if mask.dim() >= 2 and mask.size(1) != seq_len:
            raise ValueError(f"seq_len={seq_len} does not match assessment seq dim {mask.size(1)}")
        return mask

    def get_depth_schedule(self, assessment: ThinkingAssessment) -> torch.Tensor:
        """Per-token depth schedule derived from complexity.

        Linearly maps complexity in [0, 1] onto [depth_min, depth_max];
        in NON_THINKING mode the schedule is capped at depth_min + 0.3.
        """
        complexity = assessment.complexity_score
        depth = self.depth_min + (self.depth_max - self.depth_min) * complexity
        if assessment.mode == ThinkingMode.NON_THINKING:
            depth = depth.clamp(max=self.depth_min + 0.3)
        return depth