"""AAM Diffusion LLM — Speculative Decoder

Draft model (graph encoder quick prediction) generates candidates,
main diffusion model verifies and accepts/rejects.

For AAM, the graph encoder can serve as the draft model since
it already produces a quick prediction of the narrative.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SpeculativeConfig:
    d_model: int = 768
    d_vocab: int = 32000
    n_draft_tokens: int = 5
    acceptance_threshold: float = 0.1  # accept if p_main / p_draft >= 1 - threshold


class SpeculativeDecoder(nn.Module):
    """Speculative Decoder for AAM.
    
    Uses graph encoder as draft model, diffusion model as verifier.
    1. Draft: graph encoder produces quick token predictions
    2. Verify: diffusion model evaluates each prediction
    3. Accept/Reject: keep tokens that pass verification
    """

    def __init__(self, config: Optional[SpeculativeConfig] = None) -> None:
        super().__init__()
        self.config = config or SpeculativeConfig()
        self.d_model = self.config.d_model
        self.d_vocab = self.config.d_vocab

        # Draft head (lightweight, from graph conditioning)
        self.draft_head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model // 2, bias=False),
            nn.SiLU(),
            nn.Linear(self.d_model // 2, self.d_vocab, bias=False),
        )

        # Verification projection: maps diffusion hidden states to vocab
        # logits (exposed via the project_hidden helper below)
        self.verify_proj = nn.Sequential(
            nn.Linear(self.d_model, self.d_model, bias=False),
            nn.SiLU(),
            nn.Linear(self.d_model, self.d_vocab, bias=False),
        )
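
    def project_hidden(self, hidden: torch.Tensor) -> torch.Tensor:
        """Project diffusion hidden states to vocabulary logits.

        Convenience sketch (an assumption, not part of the original API):
        one way to build the ``main_model_fn`` expected by :meth:`forward`,
        e.g. ``lambda ids: decoder.project_hidden(diffusion_backbone(ids))``
        where ``diffusion_backbone`` is a hypothetical stand-in for the
        diffusion model's hidden-state pass.
        """
        return self.verify_proj(hidden)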

    def draft(
        self,
        graph_hidden: torch.Tensor,
        n_tokens: int = 5,
        temperature: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate draft tokens from graph conditioning.
        
        Args:
            graph_hidden: Graph encoder output (batch, n_nodes, d_model)
            n_tokens: Number of draft tokens to generate
            temperature: Sampling temperature
            
        Returns:
            Tuple (draft_token_ids, draft_log_probs), each of shape (batch, n_tokens)
        """
        # Use the mean-pooled graph representation as conditioning.
        pooled = graph_hidden.mean(dim=1)  # (batch, d_model)

        all_tokens = []
        all_log_probs = []

        # Each draft token is sampled independently from the same pooled
        # conditioning; the draft head has no autoregressive feedback.
        for _ in range(n_tokens):
            logits = self.draft_head(pooled) / temperature
            log_probs = F.log_softmax(logits, dim=-1)
            probs = log_probs.exp()
            token_ids = torch.multinomial(probs, 1).squeeze(-1)

            all_tokens.append(token_ids)
            selected_log_probs = log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
            all_log_probs.append(selected_log_probs)

        draft_token_ids = torch.stack(all_tokens, dim=1)  # (batch, n_tokens)
        draft_log_probs = torch.stack(all_log_probs, dim=1)  # (batch, n_tokens)

        return draft_token_ids, draft_log_probs

    def verify(
        self,
        draft_token_ids: torch.Tensor,
        main_logits: torch.Tensor,
        draft_log_probs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Verify draft tokens against main model logits.
        
        Args:
            draft_token_ids: Draft token IDs (batch, n_tokens)
            main_logits: Main model logits (batch, n_tokens, d_vocab)
            draft_log_probs: Draft log probs (batch, n_tokens)
            
        Returns:
            Tuple (accepted_mask, verified_token_ids)
        """
        main_log_probs = F.log_softmax(main_logits, dim=-1)
        selected_main_log_probs = main_log_probs.gather(-1, draft_token_ids.unsqueeze(-1)).squeeze(-1)

        # Accept when the main model's probability of the drafted token is
        # close to or higher than the draft's. This is a deterministic
        # simplification of the standard stochastic rule, which accepts with
        # probability min(1, p_main / p_draft); see classic_accept below.
        ratio = torch.exp(selected_main_log_probs - draft_log_probs)
        accepted = (ratio >= (1.0 - self.config.acceptance_threshold)).float()

        # Where rejected, resample from the main model. Positions are handled
        # independently here; classic speculative decoding would also discard
        # everything after the first rejection and resample from the residual
        # distribution norm(max(0, p - q)) rather than from p directly.
        rejected_mask = accepted == 0
        if rejected_mask.any():
            main_probs = torch.exp(main_log_probs)
            resampled = torch.multinomial(
                main_probs.view(-1, self.d_vocab), 1
            ).view(draft_token_ids.shape)
            verified = torch.where(rejected_mask, resampled, draft_token_ids)
        else:
            verified = draft_token_ids

        return accepted, verified
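
    @staticmethod
    def classic_accept(
        draft_log_probs: torch.Tensor,
        selected_main_log_probs: torch.Tensor,
    ) -> torch.Tensor:
        """Reference sketch of the standard stochastic acceptance rule.

        Not wired into :meth:`verify` (illustrative only): classic
        speculative sampling accepts each draft token with probability
        min(1, p_main / p_draft) rather than applying the deterministic
        threshold above.
        """
        ratio = torch.exp(selected_main_log_probs - draft_log_probs)
        accept_prob = torch.clamp(ratio, max=1.0)  # min(1, p/q) per token
        return torch.bernoulli(accept_prob)  # 1.0 = accept, 0.0 = reject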

    def forward(
        self,
        graph_hidden: torch.Tensor,
        main_model_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        n_tokens: Optional[int] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """Full speculative decoding pipeline.
        
        Args:
            graph_hidden: Graph encoder output
            main_model_fn: Callable mapping draft token IDs (batch, n_tokens)
                to main-model logits (batch, n_tokens, d_vocab)
            n_tokens: Number of draft tokens (defaults to config.n_draft_tokens)
            
        Returns:
            Tuple (verified_token_ids, info_dict)
        """
        n_tokens = n_tokens or self.config.n_draft_tokens

        # Draft
        draft_ids, draft_log_probs = self.draft(graph_hidden, n_tokens)

        if main_model_fn is not None:
            # Verify with main model
            main_logits = main_model_fn(draft_ids)
            accepted, verified = self.verify(draft_ids, main_logits, draft_log_probs)

            info = {
                "n_draft": n_tokens,
                "n_accepted": accepted.sum(dim=-1).mean().item(),
                "acceptance_rate": accepted.mean().item(),
            }
            return verified, info
        else:
            # No verifier supplied; return the raw draft tokens unchanged.
            return draft_ids, {
                "n_draft": n_tokens,
                "n_accepted": float(n_tokens),
                "acceptance_rate": 1.0,
            }
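

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): exercises
    # the draft/verify pipeline with random tensors. The dummy verifier below
    # is an assumption standing in for the real diffusion model.
    torch.manual_seed(0)
    decoder = SpeculativeDecoder()
    cfg = decoder.config

    # Fake graph-encoder output: (batch=2, n_nodes=4, d_model).
    graph_hidden = torch.randn(2, 4, cfg.d_model)

    def dummy_main_model_fn(token_ids: torch.Tensor) -> torch.Tensor:
        # Stand-in for the diffusion verifier: random logits of shape
        # (batch, n_tokens, d_vocab).
        return torch.randn(*token_ids.shape, cfg.d_vocab)

    verified, info = decoder(graph_hidden, main_model_fn=dummy_main_model_fn)
    print("verified token ids:", tuple(verified.shape))  # (2, 5)
    print("info:", info)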