| """ |
| reverse_process.py — Final Correct Version |
| ============================================= |
| |
| KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference() |
| in inference.py, which is what produced BERTScore 0.75 at validation. |
| |
| CRITICAL BUG IN PREVIOUS VERSION: |
| We passed inference_mode=True to the model, but the model was NEVER |
| called with inference_mode=True during training or validation. |
| run_inference() (the validated path) does: |
| model(input_ids, x0_est, t, x0_hint=hint) |
| → inference_mode defaults to False. |
| |
| With inference_mode=True the model does two things differently: |
| 1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD) |
| 2. Skips q_sample at t=0 (training always called q_sample) |
| The model was never trained to handle these conditions → garbage output. |
| |
| Fix: do NOT pass inference_mode. Let it default to False, exactly |
| as run_inference() did. |
| |
| BUGS FIXED (vs original reverse_process.py) |
| -------------------------------------------- |
| BUG 1 generate_beam() used for D3PM → degenerate single-token repetition. |
| Use generate() (iterative refinement) from app1.py instead. |
| BUG 2 apply_diversity_penalty used logits.var() → noise injection. |
| Fixed to logits - penalty * logits.mean(dim=1) → global suppression. |
| BUG 3 x0_hint (self-conditioning) never passed to model. |
| Fixed: generate() passes x0_hint=hint every step. |
| BUG 4 params not forwarded from generate_beam() to p_sample_step(). |
| Fixed in generate_beam() (kept for reference, not for production use). |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
class ReverseDiffusion:
    """Reverse (denoising) sampler for a D3PM-style discrete diffusion model.

    Two decoding strategies are provided:

    * ``generate()`` - iterative refinement; per the module notes this is the
      validated production path.
    * ``generate_beam()`` - experimental beam-style decoding built on
      ``p_sample_step()``; kept for reference only.

    The model is always called as ``model(condition, x, t, x0_hint=...)`` and
    is expected to return a ``(logits, extra)`` pair. ``inference_mode`` is
    deliberately never passed, so the model keeps its own default for it.
    """

    def __init__(self, scheduler):
        """Store the scheduler and the default sampling hyper-parameters.

        Args:
            scheduler: Object exposing ``num_timesteps`` (the only attribute
                this class reads from it).
        """
        self.scheduler = scheduler

        # Defaults used whenever the matching keyword argument is None.
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.top_k = 50

    def _build_timesteps(self, num_steps):
        """Return the descending timestep schedule from T-1 down to 0.

        Args:
            num_steps: Requested number of denoising steps, or None to use
                every scheduler timestep.

        Returns:
            A list of ints starting at ``T - 1``, descending with stride
            ``max(1, T // num_steps)`` and always ending with ``0``.

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        total = self.scheduler.num_timesteps
        if num_steps is None:
            num_steps = total
        if num_steps < 1:
            # Previously num_steps == 0 surfaced as a bare ZeroDivisionError.
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        stride = max(1, total // num_steps)
        timesteps = list(range(total - 1, -1, -stride))
        # The final pass must run at t == 0 even when the stride skips it.
        if timesteps[-1] != 0:
            timesteps.append(0)
        return timesteps

    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=None,
        top_k=None,
        repetition_penalty=None,
        diversity_penalty=None,
    ):
        """Decode by iterative refinement, starting from an all-[MASK] estimate.

        Algorithm:
            x0_est = all [MASK]
            for t in schedule (T-1 ... 0):
                logits = model(condition, x0_est, t, x0_hint=hint)
                apply repetition penalty, diversity penalty, temperature, top-k
                x0_est = multinomial sample   (every step but the last)
                x0_est = argmax               (last step, deterministic)
                hint   = x0_est

        Args:
            model: Callable returning ``(logits, extra)``; must expose
                ``mask_token_id`` and ``eval()``.
            condition: Source token ids, shape ``(L,)`` or ``(B, L)``.
            num_steps: Number of refinement steps; None uses all scheduler
                timesteps.
            temperature: Softmax temperature (floored at 1e-5); None uses the
                instance default.
            top_k: Keep only the k best logits per position; values <= 0
                disable the filter. None uses the instance default.
            repetition_penalty: Passed to ``apply_repetition_penalty``; 1.0
                disables it. None uses the instance default.
            diversity_penalty: Passed to ``apply_diversity_penalty``; values
                <= 0 disable it. None uses the instance default.

        Returns:
            LongTensor of shape ``(B, L)`` holding the final token estimate.

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        if temperature is None:
            temperature = self.temperature
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if diversity_penalty is None:
            diversity_penalty = self.diversity_penalty

        timesteps = self._build_timesteps(num_steps)

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        x0_est = torch.full(
            (B, L), model.mask_token_id, dtype=torch.long, device=device
        )
        hint = None  # no self-conditioning hint exists before the first pass

        model.eval()
        with torch.no_grad():
            last_idx = len(timesteps) - 1
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)

                # inference_mode is intentionally NOT passed (see class doc).
                logits, _ = model(condition, x0_est, t, x0_hint=hint)

                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, x0_est, repetition_penalty
                    )
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                logits = logits / max(temperature, 1e-5)

                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                # Stochastic on intermediate steps, deterministic on the last.
                if step_idx == last_idx:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)

                hint = x0_est

        return x0_est

    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3,
        x0_hint=None,
    ):
        """Run one model pass and return the top-ranked token candidates.

        Candidate ``k`` assigns to every position its k-th most probable
        token; its score is the sum of the log-probabilities of those tokens
        over all positions and all batch rows.

        Args:
            model: Callable returning ``(logits, extra)``.
            x_t: Current token ids, shape ``(L,)`` or ``(B, L)``.
            t: Timestep tensor, scalar or shape ``(B,)``; expanded to the
                batch size of ``x_t`` when the sizes differ.
            condition: Source token ids, shape ``(L,)`` or ``(B, L)``.
            beam_width: Number of candidates requested; clamped to the
                vocabulary size so ``torch.topk`` cannot fail on a small
                vocabulary.
            temperature: Softmax temperature (floored at 1e-5).
            repetition_penalty: 1.0 disables the repetition penalty.
            diversity_penalty: Values <= 0 disable the diversity penalty.
            x0_hint: Optional self-conditioning hint forwarded to the model.

        Returns:
            List of ``(tokens, score)`` tuples ordered best-first, where
            ``tokens`` has shape ``(B, L)`` and ``score`` is a 0-dim tensor.
        """
        with torch.no_grad():
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)
            if condition.dim() == 1:
                condition = condition.unsqueeze(0)
            if t.dim() == 0:
                t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # inference_mode is intentionally NOT passed (see class doc).
            logits, _ = model(condition, x_t, t, x0_hint=x0_hint)

            logits = logits / max(temperature, 1e-5)

            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)

            width = min(beam_width, probs.shape[-1])
            topk_probs, topk_ids = torch.topk(probs, width, dim=-1)
            return [
                (
                    topk_ids[:, :, rank],
                    # 1e-9 keeps log() finite for zero probabilities.
                    torch.log(topk_probs[:, :, rank] + 1e-9).sum(),
                )
                for rank in range(width)
            ]

    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=None,
        repetition_penalty=None,
        diversity_penalty=None,
    ):
        """Experimental beam-style decoding; prefer ``generate()``.

        WARNING: every candidate produced by ``p_sample_step()`` picks the
        same rank at all positions, which the original author reports leads
        to degenerate, repetitive output for D3PM. Kept only for reference.

        The timestep schedule is the same one ``generate()`` uses (T-1 down
        to 0, strided by ``num_steps``), so the model is never asked for a
        timestep outside the scheduler's range.

        Args:
            model: Callable returning ``(logits, extra)``; must expose
                ``mask_token_id`` and ``eval()``.
            condition: Source token ids, shape ``(L,)`` or ``(B, L)``.
            beam_width: Number of beams kept after each step.
            num_steps: Number of denoising steps; None uses all scheduler
                timesteps.
            temperature: None uses the instance default.
            repetition_penalty: None uses the instance default.
            diversity_penalty: None uses the instance default.

        Returns:
            LongTensor of shape ``(B, L)``: the tokens of the best beam.

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        if temperature is None:
            temperature = self.temperature
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if diversity_penalty is None:
            diversity_penalty = self.diversity_penalty

        timesteps = self._build_timesteps(num_steps)

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        x_init = torch.full(
            (B, L), fill_value=model.mask_token_id, dtype=torch.long, device=device
        )
        beams = [(x_init, 0.0)]
        best_hint = None

        # Same as generate(): decode with the model in evaluation mode.
        model.eval()
        for t_val in timesteps:
            t_tensor = torch.full((B,), t_val, dtype=torch.long, device=device)
            expanded = []
            for x_t, score in beams:
                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    beam_width=beam_width,
                    temperature=temperature,
                    repetition_penalty=repetition_penalty,
                    diversity_penalty=diversity_penalty,
                    x0_hint=best_hint,
                )
                expanded.extend(
                    (tokens, score + step_score.item())
                    for tokens, step_score in candidates
                )

            # Stable sort: ties keep their insertion order.
            expanded.sort(key=lambda item: item[1], reverse=True)
            beams = expanded[:beam_width]
            best_hint = beams[0][0]

        return beams[0][0]
|
|
|
|
| |
|
|
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2, special_token_max=4):
    """Down-weight tokens that already occur in the sequence.

    For every batch row, each token id found in ``prev_tokens[b]`` that is
    greater than ``special_token_max`` has its logit penalised at all
    positions of that row. The penalty is sign-aware: positive logits are
    divided by ``penalty`` and negative logits are multiplied by it, so a
    ``penalty > 1`` always lowers the score. Plain division would raise a
    negative logit and thereby favour the repeated token.

    Args:
        logits: Float tensor indexed as ``[batch, position, vocab]``.
            Modified in place.
        prev_tokens: Integer tensor indexed as ``[batch, position]`` with the
            token ids already present in each sequence.
        penalty: Penalty factor; 1.0 leaves the logits unchanged.
        special_token_max: Token ids less than or equal to this value are
            never penalised (the original hard-coded cutoff was 4).

    Returns:
        The same ``logits`` tensor, after the in-place update.
    """
    for b in range(logits.shape[0]):
        seen = torch.unique(prev_tokens[b])
        seen = seen[seen > special_token_max]
        if seen.numel() == 0:
            continue
        # Advanced indexing returns a copy, so read, transform, write back.
        picked = logits[b, :, seen]
        logits[b, :, seen] = torch.where(
            picked < 0, picked * penalty, picked / penalty
        )
    return logits
|
|
|
|
def apply_diversity_penalty(logits, penalty=0.3):
    """Suppress tokens whose logits are high across the whole sequence.

    Computes the mean logit of every vocabulary entry over dimension 1 (the
    sequence dimension) and subtracts ``penalty`` times that mean from each
    position.

    Args:
        logits: Float tensor whose dimension 1 is the sequence dimension.
        penalty: Scale applied to the per-token sequence mean.

    Returns:
        A new tensor with the same shape as ``logits``.
    """
    sequence_mean = torch.mean(logits, dim=1, keepdim=True)
    correction = penalty * sequence_mean
    return logits - correction
|
|
|
|
def top_k_filter(logits, k):
    """Mask every logit below the k-th largest along the last dimension.

    Args:
        logits: Float tensor whose last dimension is the vocabulary.
        k: Number of top entries to keep. ``k <= 0`` disables the filter and
            ``k >= vocab size`` keeps everything; in both cases the input
            tensor is returned unchanged.

    Returns:
        A tensor shaped like ``logits`` in which entries strictly smaller
        than the k-th largest value are set to ``-inf``. Entries tied with
        the k-th value are kept, so more than k may survive.
    """
    vocab_size = logits.shape[-1]
    # k <= 0 would make torch.topk return an empty tensor and the threshold
    # lookup fail; treat it as "no filtering", as the caller already does.
    if k <= 0 or k >= vocab_size:
        return logits
    threshold = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < threshold, float("-inf"))
|
|
|
|
def batch_multinomial(probs):
    """Draw one token per position from a batch of categorical distributions.

    Args:
        probs: Tensor of shape ``(B, L, V)`` holding per-position
            probabilities over ``V`` categories.

    Returns:
        Integer tensor of shape ``(B, L)`` with one sampled index per
        position.
    """
    batch, seq_len, vocab = probs.shape
    # A small epsilon keeps every row strictly positive before renormalising.
    rows = probs.view(batch * seq_len, vocab) + 1e-9
    row_totals = rows.sum(dim=-1, keepdim=True)
    rows = rows / row_totals
    draws = torch.multinomial(rows, 1)
    return draws.squeeze(-1).view(batch, seq_len)