"""AAM Diffusion LLM — Thinking Toggle

Detects whether an input needs deep reasoning (thinking) or a quick
response (non-thinking). AAM-specific policy: a simple factual query gets
2 anchored steps; complex reasoning gets 5-10 steps plus MCTS.
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class ThinkingMode(Enum):
    THINKING = "thinking"
    NON_THINKING = "non_thinking"


class TaskType(Enum):
    SEQUENTIAL = "sequential"
    REASONING = "reasoning"
    FACTUAL = "factual"
    CREATIVE = "creative"
    ANOMALY_RESOLUTION = "anomaly_resolution"


@dataclass
class ThinkingAssessment:
    mode: ThinkingMode
    complexity_score: torch.Tensor
    task_type_probs: torch.Tensor
    dominant_task: TaskType
    depth_multiplier: torch.Tensor
    confidence: torch.Tensor
    thinking_score: Optional[torch.Tensor] = None


class ThinkingToggle(nn.Module):
    """Thinking/Non-Thinking Toggle for AAM Diffusion LLM."""

    NUM_TASK_TYPES = len(TaskType)

    def __init__(self, d_model: int, threshold: float = 0.5) -> None:
        super().__init__()
        self.d_model = d_model
        self.threshold = threshold

        self.complexity_scorer = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, d_model // 4),
            nn.SiLU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid(),
        )

        self.task_classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Linear(d_model // 2, self.NUM_TASK_TYPES),
        )

        self.context_integrator = nn.Sequential(
            nn.Linear(1 + self.NUM_TASK_TYPES, d_model // 4),
            nn.SiLU(),
            nn.Linear(d_model // 4, 1),
            nn.Sigmoid(),
        )

        self.depth_min = 0.3
        self.depth_max = 2.0

        self.register_buffer("_force_mode_code", torch.tensor(-1, dtype=torch.long), persistent=True)

    def forward(self, x: torch.Tensor, force_mode: Optional[ThinkingMode] = None) -> ThinkingAssessment:
        if x.dim() != 3:
            raise ValueError(f"Input must be 3D [batch, seq, d_model], got {x.dim()}D")

        complexity = self.complexity_scorer(x).squeeze(-1)
        task_logits = self.task_classifier(x)
        task_probs = F.softmax(task_logits, dim=-1)

        mean_complexity = complexity.mean(dim=-1, keepdim=True)
        mean_task_probs = task_probs.mean(dim=1)

        context_input = torch.cat([mean_complexity, mean_task_probs], dim=-1)
        thinking_score = self.context_integrator(context_input).squeeze(-1)

        # v2.3.0: Use force_mode kwarg if provided (thread-safe, no state mutation).
        # Falls back to _get_force_mode() for backward compatibility.
        if force_mode is not None:
            mode = force_mode
        else:
            persistent_mode = self._get_force_mode()
            if persistent_mode is not None:
                mode = persistent_mode
            else:
                # v1.8.0: Hard threshold picks the discrete mode for control flow; .item()
                # intentionally detaches this branch from the graph. Gradients flow instead
                # through the parallel soft sigmoid blend used for depth_multiplier below
                # (a soft relaxation, not a straight-through estimator).
                avg_score_val = thinking_score.mean().item()
                mode = ThinkingMode.THINKING if avg_score_val > self.threshold else ThinkingMode.NON_THINKING

        overall_task_probs = task_probs.mean(dim=(0, 1))
        dominant_task_idx = overall_task_probs.argmax().item()
        dominant_task = list(TaskType)[dominant_task_idx]

        # Soft blend keeps depth_multiplier differentiable even though mode selection is hard.
        # thinking_score is per-sample [batch], not averaged, so each sample gets its own depth.
        temperature = 5.0
        mode_weight = torch.sigmoid(temperature * (thinking_score - self.threshold))

        thinking_depth = self.depth_min + (self.depth_max - self.depth_min) * thinking_score
        non_thinking_depth = self.depth_min + 0.2 * thinking_score
        depth_multiplier = mode_weight * thinking_depth + (1.0 - mode_weight) * non_thinking_depth

        # Confidence grows with distance from the decision threshold: a score at the
        # threshold is maximally ambiguous (0.0); a score at 0 or 1 is decisive (1.0).
        confidence = (thinking_score - self.threshold).abs() / max(self.threshold, 1.0 - self.threshold)
        confidence = confidence.clamp(0.0, 1.0)

        return ThinkingAssessment(
            mode=mode,
            complexity_score=complexity,
            task_type_probs=task_probs,
            dominant_task=dominant_task,
            depth_multiplier=depth_multiplier,
            confidence=confidence,
            thinking_score=thinking_score,
        )

    def _get_force_mode(self) -> Optional[ThinkingMode]:
        """Decode the persistent buffer back to a ThinkingMode (or None)."""
        code = int(self._force_mode_code.item())
        if code == -1:
            return None
        elif code == 0:
            return ThinkingMode.NON_THINKING
        elif code == 1:
            return ThinkingMode.THINKING
        else:
            # Corrupted value — reset to auto
            self._force_mode_code.fill_(-1)
            return None

    def set_force_mode(self, mode: Optional[ThinkingMode]) -> None:
        """Force mode, bypassing detection. Set None for automatic detection.

        The mode is persisted via a registered buffer so it survives
        model.state_dict() / model.load_state_dict() round-trips.
        """
        if mode is None:
            self._force_mode_code.fill_(-1)
        elif mode == ThinkingMode.NON_THINKING:
            self._force_mode_code.fill_(0)
        elif mode == ThinkingMode.THINKING:
            self._force_mode_code.fill_(1)
        else:
            raise ValueError(f"Unknown ThinkingMode: {mode!r}")

    def set_threshold(self, threshold: float) -> None:
        """Update complexity threshold.

        Args:
            threshold: New threshold value (0.0 - 1.0)
        """
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"Threshold must be between 0.0 and 1.0, got {threshold}")
        self.threshold = threshold

    def get_thinking_mask(self, assessment: ThinkingAssessment, seq_len: int) -> torch.Tensor:
        """Create binary mask marking which tokens need thinking.

        Args:
            assessment: Assessment result from forward
            seq_len: Sequence length (unused; the mask shape follows
                assessment.complexity_score)

        Returns:
            Mask [batch, seq] — 1.0 for thinking, 0.0 for non-thinking
        """
        return (assessment.complexity_score > self.threshold).float()

    def get_depth_schedule(self, assessment: ThinkingAssessment) -> torch.Tensor:
        """Per-token depth multipliers derived from complexity scores.

        Args:
            assessment: Assessment result from forward

        Returns:
            Depth multipliers [batch, seq], scaled into [depth_min, depth_max];
            capped near depth_min when the assessed mode is non-thinking
        """
        complexity = assessment.complexity_score
        depth = self.depth_min + (self.depth_max - self.depth_min) * complexity
        if assessment.mode == ThinkingMode.NON_THINKING:
            depth = depth.clamp(max=self.depth_min + 0.3)
        return depth
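

# ----------------------------------------------------------------------
# Hypothetical helper (an assumption, not part of the shipped AAM API):
# one plausible way to map a depth multiplier onto the step policy from
# the module docstring, roughly 2 anchored steps for quick answers up to
# 10 steps for deep reasoning. base_steps=5 is an illustrative default.
# ----------------------------------------------------------------------
def depth_to_num_steps(depth_multiplier: torch.Tensor, base_steps: int = 5) -> torch.Tensor:
    """Round depth_multiplier * base_steps into a [2, 10] step budget."""
    steps = (depth_multiplier * base_steps).round().long()
    return steps.clamp(min=2, max=10)


# Minimal usage sketch: exercises the toggle on random activations.
# d_model=64 and the shapes below are illustrative, not AAM config values.
if __name__ == "__main__":
    torch.manual_seed(0)
    toggle = ThinkingToggle(d_model=64, threshold=0.5)
    x = torch.randn(2, 16, 64)  # [batch, seq, d_model]

    # Automatic mode detection.
    assessment = toggle(x)
    print("mode:", assessment.mode.value)
    print("dominant task:", assessment.dominant_task.value)
    print("suggested steps:", depth_to_num_steps(assessment.depth_multiplier))

    # The kwarg forces a mode for one call without mutating module state.
    forced = toggle(x, force_mode=ThinkingMode.THINKING)
    assert forced.mode is ThinkingMode.THINKING

    # Persistent forcing survives a state_dict round-trip via the buffer.
    toggle.set_force_mode(ThinkingMode.NON_THINKING)
    clone = ThinkingToggle(d_model=64)
    clone.load_state_dict(toggle.state_dict())
    assert clone(x).mode is ThinkingMode.NON_THINKING
    toggle.set_force_mode(None)  # back to automatic detection

    # Per-token thinking mask and depth schedule.
    mask = toggle.get_thinking_mask(assessment, seq_len=16)
    schedule = toggle.get_depth_schedule(assessment)
    print("thinking tokens per sample:", mask.sum(dim=-1).tolist())
    print("depth schedule shape:", tuple(schedule.shape))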