| import torch |
| import torch.nn.functional as F |
|
|
|
|
class ReverseDiffusion:
    """
    Stable reverse diffusion sampler over discrete token sequences.

    Supports:
      - Beam search decoding
      - Self-conditioning (previous step's best tokens fed back to the model)
      - Temperature sampling
      - Repetition penalty (CTRL-style)
      - Diversity penalty (variance-based logit bonus)
    """

    def __init__(self, scheduler):
        """
        Args:
            scheduler: noise scheduler; only its ``num_timesteps`` attribute
                is read here (default step count for :meth:`generate_beam`).
        """
        self.scheduler = scheduler

        # Sampling hyper-parameters (defaults unchanged).
        self.temperature = 0.75         # < 1.0 sharpens the distribution
        self.repetition_penalty = 1.15  # > 1.0 discourages repeated tokens
        self.diversity_penalty = 0.0    # 0.0 disables the diversity bonus
        self.length_penalty = 1.0       # exponent for length normalization

    def apply_repetition_penalty(self, logits, tokens):
        """
        Discourage tokens already present in ``tokens`` (modifies in place).

        Fix vs. the previous version: dividing a *negative* logit by a
        penalty > 1 moves it toward zero and therefore makes the token
        MORE likely.  Following the CTRL convention (Keskar et al., 2019),
        positive logits are divided and negative logits are multiplied by
        the penalty, so a used token is always made less likely.

        Args:
            logits: float tensor of shape (B, L, V); modified in place.
            tokens: long tensor of shape (B, L) with current token ids.

        Returns:
            The penalized ``logits`` tensor (the same object passed in).
        """
        if self.repetition_penalty == 1.0:
            return logits  # fast path: a penalty of 1.0 is a no-op

        B, L, V = logits.shape
        for b in range(B):
            # Unique ids only, so each token is penalized exactly once.
            for token_id in set(tokens[b].tolist()):
                col = logits[b, :, token_id]
                logits[b, :, token_id] = torch.where(
                    col > 0,
                    col / self.repetition_penalty,
                    col * self.repetition_penalty,
                )
        return logits

    def apply_diversity_penalty(self, logits):
        """
        Add a variance-proportional bonus to the logits.

        NOTE(review): adding a per-position scalar (variance over the vocab
        axis) shifts every logit at that position equally, which softmax
        cancels out — presumably this is a placeholder hook.  Behavior is
        preserved as-is.

        Args:
            logits: float tensor of shape (B, L, V).

        Returns:
            ``logits`` unchanged when ``diversity_penalty == 0``, otherwise
            ``logits + diversity_penalty * var(logits, dim=-1)``.
        """
        if self.diversity_penalty == 0:
            return logits

        logits_var = logits.var(dim=-1, keepdim=True)
        return logits + self.diversity_penalty * logits_var

    def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
        """
        One reverse step: propose and score candidate denoised sequences.

        Args:
            model: callable ``model(condition, x_t, t, self_cond)`` returning
                ``(logits, hidden)`` with ``logits`` of shape (B, L, V).
            x_t: current noisy token ids, shape (B, L).
            t: timestep tensor, shape (B,).
            condition: conditioning token ids.
            self_cond: optional self-conditioning tokens from the prior step.
            beam_width: number of candidates to return.

        Returns:
            List of ``(tokens, score)`` pairs, one per candidate, where
            ``tokens`` has shape (B, L) and ``score`` is a 0-dim tensor of
            summed log-probabilities.  The sum runs over batch AND length,
            so scores are only a meaningful ranking for B == 1.
        """
        with torch.no_grad():
            logits, hidden = model(condition, x_t, t, self_cond)

            # Temperature first, then the two logit-shaping penalties.
            logits = logits / self.temperature
            logits = self.apply_repetition_penalty(logits, x_t)
            logits = self.apply_diversity_penalty(logits)

            probs = F.softmax(logits, dim=-1)

            # Per-position top-k: candidate k takes the k-th best token at
            # every position simultaneously.
            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)

            candidates = []
            for k in range(beam_width):
                tokens = topk_ids[:, :, k]
                # Summed log-prob; the epsilon guards against log(0).
                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((tokens, score))

        return candidates

    def generate_beam(self, model, condition, beam_width=3, num_steps=None):
        """
        Run the full reverse diffusion process with beam search.

        Args:
            model: denoising model (see :meth:`p_sample_step`); must also
                expose a ``mask_token_id`` attribute.
            condition: conditioning token ids, shape (L,) or (B, L).
            beam_width: number of beams kept per step.
            num_steps: number of reverse steps; defaults to
                ``scheduler.num_timesteps``.

        Returns:
            Token tensor of shape (B, L) for the highest-scoring beam.
        """
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)  # promote to (1, L)
        B, L = condition.shape

        # Initialize from the condition with roughly half the positions
        # replaced by the model's mask token.
        x_init = condition.clone()
        mask = torch.rand_like(x_init.float()) < 0.5
        x_init[mask] = model.mask_token_id

        # Each beam carries (tokens, cumulative log-prob).  Fix vs. the
        # previous version: the running score is kept UN-normalized and the
        # length penalty is applied only when ranking.  (Previously the
        # already-normalized score was divided by the length norm again on
        # every step, exponentially decaying older contributions.)
        beams = [(x_init, 0.0)]
        self_cond = None
        length_norm = L ** self.length_penalty  # loop-invariant, hoisted

        for step in reversed(range(num_steps)):
            new_beams = []

            for x_t, score in beams:
                t_tensor = torch.full(
                    (B,),
                    step,
                    dtype=torch.long,
                    device=device,
                )

                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    self_cond,
                    beam_width,
                )

                for tokens, new_score in candidates:
                    # Accumulate the raw log-prob; normalize only for ranking.
                    new_beams.append((tokens, score + new_score))

            # Keep the top beams by length-normalized cumulative score.
            new_beams.sort(key=lambda beam: beam[1] / length_norm, reverse=True)
            beams = new_beams[:beam_width]

            # Self-conditioning: feed the current best tokens to the next step.
            self_cond = beams[0][0]

        best_tokens, best_score = beams[0]
        return best_tokens