File size: 11,337 Bytes
27f26fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""
reverse_process.py  β€” Final Correct Version
=============================================

KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
in inference.py, which is what produced BERTScore 0.75 at validation.

CRITICAL BUG IN PREVIOUS VERSION:
  We passed inference_mode=True to the model, but the model was NEVER
  called with inference_mode=True during training or validation.
  run_inference() (the validated path) does:
      model(input_ids, x0_est, t, x0_hint=hint)
  β†’ inference_mode defaults to False.

  With inference_mode=True the model does two things differently:
    1. tgt_pad_mask = None  (training used tgt_pad_mask = tgt==PAD)
    2. Skips q_sample at t=0 (training always called q_sample)
  The model was never trained to handle these conditions β†’ garbage output.

  Fix: do NOT pass inference_mode. Let it default to False, exactly
  as run_inference() did.

BUGS FIXED (vs original reverse_process.py)
--------------------------------------------
BUG 1  generate_beam() used for D3PM β†’ all-Ṛ repetition.
       Use generate() (iterative refinement) from app1.py instead.
BUG 2  apply_diversity_penalty used logits.var() β†’ noise injection.
       Fixed to logits - penalty * logits.mean(dim=1) β€” global suppression.
BUG 3  x0_hint (self-conditioning) never passed to model.
       Fixed: generate() passes x0_hint=hint every step.
BUG 4  params not forwarded from generate_beam() to p_sample_step().
       Fixed in generate_beam() (kept for reference, not for production use).
"""

import torch
import torch.nn.functional as F


class ReverseDiffusion:
    """Reverse (denoising) process for a D3PM-style discrete diffusion model.

    ``generate()`` is the production sampling path and mirrors the validated
    ``run_inference()`` in inference.py exactly (see module docstring).
    ``p_sample_step()`` / ``generate_beam()`` are kept only for experimental
    reference and are documented as NOT correct for D3PM generation.
    """

    def __init__(self, scheduler):
        """Store the noise scheduler and default sampling hyper-parameters.

        Args:
            scheduler: object exposing ``num_timesteps`` (total diffusion
                steps). Presumably the same scheduler used at training
                time — confirm against the training script.
        """
        self.scheduler = scheduler

        # Attribute-style defaults for backward compat with any code
        # that sets  reverse_diffusion.temperature = 0.9 etc.
        # generate() prefers explicit kwargs and falls back to these.
        self.temperature        = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty  = 0.0
        self.top_k              = 50

    # ------------------------------------------------------------------ #
    #  generate  — CORRECT D3PM iterative refinement                      #
    #  Exact equivalent of run_inference() in inference.py                #
    # ------------------------------------------------------------------ #
    def generate(
        self,
        model,
        condition,
        num_steps          = None,
        temperature        = None,
        top_k              = None,
        repetition_penalty = None,
        diversity_penalty  = None,
    ):
        """
        D3PM iterative refinement — identical to run_inference() in inference.py,
        which is the validated path (BERTScore 0.75).

        Args:
            model:      denoiser called as ``model(condition, x0_est, t,
                        x0_hint=hint)`` returning ``(logits, _)``; must
                        expose ``mask_token_id`` and ``eval()``.
            condition:  (B, L) or (L,) long tensor of source token ids;
                        a 1-D input is promoted to batch size 1.
            num_steps:  number of refinement steps (default: scheduler's
                        full ``num_timesteps``).
            temperature / top_k / repetition_penalty / diversity_penalty:
                        sampling knobs; ``None`` falls back to the
                        instance attributes set in ``__init__``.

        Returns:
            (B, L) long tensor of generated token ids.

        Algorithm:
          x0_est = all [MASK]
          for t = T-1 down to 0:
            logits = model(src, x0_est, t, x0_hint=hint)
                     ↑ inference_mode NOT passed (defaults to False)
                     ↑ this exactly matches training/validation
            apply penalties, temperature, top_k
            if t > 0: x0_est = multinomial(softmax(logits))   ← stochastic
            if t = 0: x0_est = argmax(softmax(logits))         ← deterministic
            hint = x0_est
        """
        # Resolve: explicit kwarg > object attribute
        temperature        = temperature        if temperature        is not None else self.temperature
        top_k              = top_k              if top_k              is not None else self.top_k
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty  = diversity_penalty  if diversity_penalty  is not None else self.diversity_penalty

        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # Subsample the T timesteps evenly; always finish at t = 0 so the
        # final (deterministic argmax) step is guaranteed to run.
        T         = self.scheduler.num_timesteps
        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)

        # Start from an all-[MASK] canvas; hint (self-conditioning) starts empty.
        mask_id = model.mask_token_id
        x0_est  = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint    = None

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t       = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)

                # ── CRITICAL: do NOT pass inference_mode ──────────────────
                # inference_mode defaults to False inside SanskritModel /
                # D3PMCrossAttention. This matches run_inference() exactly.
                # Passing inference_mode=True changes tgt_pad_mask and
                # q_sample behaviour — the model was never trained for that.
                logits, _ = model(condition, x0_est, t, x0_hint=hint)

                # Repetition penalty
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, x0_est, repetition_penalty
                    )

                # Diversity penalty (correct: global mean suppression)
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                logits = logits / max(temperature, 1e-5)

                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                # Stochastic at every step except the last (argmax at t=0)
                if is_last:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)

                # Feed this step's estimate back as the self-conditioning hint.
                hint = x0_est

        return x0_est   # (B, L)

    # ------------------------------------------------------------------ #
    #  p_sample_step  — used by generate_beam (not for production)        #
    # ------------------------------------------------------------------ #
    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width         = 3,
        temperature        = 1.0,
        repetition_penalty = 1.2,
        diversity_penalty  = 0.3,
        x0_hint            = None,
    ):
        """Single beam-proposal step (reference only — NOT D3PM-correct).

        Runs the model once and returns ``beam_width`` candidate sequences,
        each built by taking the k-th most likely token at EVERY position
        simultaneously (the source of the all-same-token failure mode noted
        in the module docstring), paired with its summed log-probability.

        Note: unlike generate(), temperature is applied BEFORE the
        penalties here — kept as-is since this path is reference-only.

        Returns:
            list of ``beam_width`` tuples ``(tokens (B, L), score 0-d tensor)``.
        """
        with torch.no_grad():
            # Promote unbatched inputs and broadcast t across the batch.
            if x_t.dim() == 1:       x_t       = x_t.unsqueeze(0)
            if condition.dim() == 1: condition  = condition.unsqueeze(0)
            if t.dim() == 0:         t          = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # No inference_mode — matches training convention
            logits, _ = model(condition, x_t, t, x0_hint=x0_hint)

            logits = logits / max(temperature, 1e-5)

            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)
            B, L, V = probs.shape

            # k-th best token at every position → k whole-sequence candidates.
            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
            candidates = []
            for k in range(beam_width):
                next_tokens = topk_ids[:, :, k]
                score       = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((next_tokens, score))
            return candidates

    # ------------------------------------------------------------------ #
    #  generate_beam  — kept for reference; NOT the correct D3PM method   #
    # ------------------------------------------------------------------ #
    def generate_beam(
        self,
        model,
        condition,
        beam_width         = 3,
        num_steps          = None,
        temperature        = None,
        repetition_penalty = None,
        diversity_penalty  = None,
    ):
        """
        WARNING: do NOT call this from app1.py for D3PM generation.
        generate_beam() forces every position to the same top-k token
        → all-Ṛ / all-rud repetition. Use generate() instead.
        Kept only for experimental reference.
        """
        # Same kwarg-over-attribute resolution as generate().
        temperature        = temperature        if temperature        is not None else self.temperature
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty  = diversity_penalty  if diversity_penalty  is not None else self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1: condition = condition.unsqueeze(0)
        B, L = condition.shape

        x_init = torch.full((B, L), fill_value=model.mask_token_id,
                            dtype=torch.long, device=device)
        beams     = [(x_init, 0.0)]
        best_hint = None

        for step in reversed(range(num_steps)):
            t_tensor  = torch.full((B,), step, dtype=torch.long, device=device)
            new_beams = []
            for x_t, score in beams:
                # Params ARE forwarded here (BUG 4 fix per module docstring).
                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition,
                    beam_width         = beam_width,
                    temperature        = temperature,
                    repetition_penalty = repetition_penalty,
                    diversity_penalty  = diversity_penalty,
                    x0_hint            = best_hint,
                )
                for tokens, new_score in candidates:
                    new_beams.append((tokens, score + new_score.item()))

            # Keep the beam_width highest-scoring candidates; best one becomes
            # the self-conditioning hint for the next step.
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            beams     = new_beams[:beam_width]
            best_hint = beams[0][0]

        return beams[0][0]   # (B, L)


# ── Penalty helpers ────────────────────────────────────────────────────────

def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """Down-weight tokens already present in the sequence.

    Vectorized replacement for the original per-batch / per-token Python
    double loop (same semantics, one fused tensor op instead of O(B*L)
    tensor-slice assignments).

    Args:
        logits:      (B, L, V) float tensor of unnormalized scores.
        prev_tokens: (B, L) long tensor of already-generated token ids.
        penalty:     divisor applied, across the whole sequence dimension,
                     to the logit columns of repeated tokens. 1.0 = no-op.

    Returns:
        (B, L, V) tensor with penalized columns; token ids <= 4 (special
        tokens) are never penalized, matching the original behaviour.

    NOTE(review): dividing a *negative* logit by penalty > 1 moves it
    toward zero and therefore RAISES that token's probability; the
    CTRL-style penalty divides positive logits and multiplies negative
    ones. Behaviour deliberately kept as-is to match the validated
    inference path — confirm before "fixing".
    """
    B, L, V = logits.shape
    # (B, V) mask: True where token id occurs in prev_tokens[b] and is
    # not a special token (ids 0..4).
    present = torch.zeros(B, V, dtype=torch.bool, device=logits.device)
    valid = prev_tokens > 4
    if valid.any():
        rows = torch.arange(B, device=logits.device).unsqueeze(1).expand(B, L)
        present[rows[valid], prev_tokens[valid]] = True
    # Broadcast over the sequence dimension: every position's logit for a
    # repeated token id is divided by `penalty` (same as the original loop).
    return torch.where(present.unsqueeze(1), logits / penalty, logits)


def apply_diversity_penalty(logits, penalty=0.3):
    """Suppress tokens that dominate the whole sequence.

    Computes ``logits - penalty * mean(logits, dim=1)``: a token whose
    logit is high at every position gets pushed down everywhere, while
    position-specific peaks are barely affected.

    Args:
        logits:  (B, L, V) float tensor of unnormalized scores.
        penalty: scaling factor for the subtracted sequence-wide mean.

    Returns:
        (B, L, V) penalized logits.
    """
    seq_mean = torch.mean(logits, dim=1, keepdim=True)  # [B, 1, V]
    return logits.sub(penalty * seq_mean)


def top_k_filter(logits, k):
    """Keep only the k largest logits per position; mask the rest to -inf.

    Args:
        logits: (B, L, V) float tensor of unnormalized scores.
        k:      number of candidates to keep; if k >= V the input is
                returned unchanged.

    Returns:
        (B, L, V) tensor where every logit strictly below the k-th best
        value at its position is replaced with -inf.
    """
    vocab_size = logits.shape[-1]
    if k >= vocab_size:
        return logits
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
    below_cutoff = logits < kth_value
    return logits.masked_fill(below_cutoff, float('-inf'))


def batch_multinomial(probs):
    """Draw one token id per position from a (B, L, V) probability tensor.

    Flattens to (B*L, V) rows, adds a small epsilon and re-normalizes for
    numeric safety, samples one index per row with ``torch.multinomial``,
    and reshapes the draws back to (B, L).
    """
    B, L, V = probs.shape
    rows = probs.view(B * L, V) + 1e-9
    rows = rows / rows.sum(dim=-1, keepdim=True)
    draws = torch.multinomial(rows, 1)
    return draws.view(B, L)