import torch
import torch.nn.functional as F


class ReverseDiffusion:
    """
    Stable reverse diffusion with:
    - Beam search
    - Self conditioning
    - Temperature sampling
    - Repetition penalty
    - Diversity penalty

    Tunable attributes (set in ``__init__``, may be overridden per instance):

    - ``temperature``: logits are divided by this value before the softmax.
    - ``repetition_penalty``: strength of the penalty applied to the logits
      of token ids already present in the current sequence.
    - ``diversity_penalty``: weight of the per-position logit variance added
      to the logits (``0`` disables the term).
    - ``length_penalty``: exponent applied to the sequence length when
      normalising beam scores for ranking.
    """

    def __init__(self, scheduler):
        # Only `scheduler.num_timesteps` is read by this class (as the
        # default number of reverse steps in `generate_beam`).
        self.scheduler = scheduler
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.length_penalty = 1.0

    # ------------------------------------------------
    # penalties
    # ------------------------------------------------

    def apply_repetition_penalty(self, logits, tokens):
        """Make token ids that already occur in ``tokens`` less likely.

        For every batch row, each vocabulary entry whose id appears anywhere
        in that row of ``tokens`` is penalised at every position. The penalty
        is sign-aware: positive logits are divided by
        ``self.repetition_penalty`` and non-positive logits are multiplied by
        it, so a penalty greater than 1 always lowers the logit.

        Args:
            logits: Float tensor of shape ``(B, L, V)``.
            tokens: Integer tensor of shape ``(B, L')`` holding token ids.
                Ids outside ``[0, V)`` are ignored.

        Returns:
            A tensor shaped like ``logits`` with the penalty applied. The
            input tensor is not modified.
        """
        penalty = self.repetition_penalty
        if penalty == 1.0:
            # Dividing or multiplying by 1 is a no-op; skip the work.
            return logits

        batch_size, _, vocab_size = logits.shape
        tokens = tokens.to(device=logits.device, dtype=torch.long)

        # Boolean (B, V) table: used[b, v] is True when id v occurs in row b.
        in_vocab = (tokens >= 0) & (tokens < vocab_size)
        rows = torch.arange(batch_size, device=logits.device)
        rows = rows.unsqueeze(1).expand_as(tokens)
        used = torch.zeros(
            batch_size, vocab_size, dtype=torch.bool, device=logits.device
        )
        used[rows[in_vocab], tokens[in_vocab]] = True

        # Plain division would push negative logits towards zero and thereby
        # *raise* the probability of repeated tokens, so branch on the sign.
        penalized = torch.where(logits > 0, logits / penalty, logits * penalty)
        return torch.where(used.unsqueeze(1), penalized, logits)

    def apply_diversity_penalty(self, logits):
        """Add ``diversity_penalty`` times the per-position logit variance.

        Returns ``logits`` unchanged when ``self.diversity_penalty`` is 0.

        NOTE(review): the variance is taken over the last axis with
        ``keepdim=True``, so the same value is added to every vocabulary
        entry of a position. A softmax over that axis is invariant to such a
        shift, so this term cannot change the probabilities computed in
        ``p_sample_step`` -- confirm the intended formula.
        """
        if self.diversity_penalty == 0:
            return logits

        logits_var = logits.var(dim=-1, keepdim=True)
        return logits + self.diversity_penalty * logits_var

    # ------------------------------------------------
    # single reverse step
    # ------------------------------------------------

    def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
        """Run one reverse step and return the top-ranked candidates.

        The model is called as ``model(condition, x_t, t, self_cond)`` and
        must return a ``(logits, hidden)`` pair; only ``logits`` (shape
        ``(B, L, V)``) is used. Temperature scaling, the repetition penalty
        (with ``x_t`` as the set of already-used tokens) and the diversity
        penalty are applied before the softmax.

        Args:
            model: Callable denoising model, see above.
            x_t: Current token tensor of shape ``(B, L)``.
            t: Timestep tensor forwarded to the model.
            condition: Conditioning tensor forwarded to the model.
            self_cond: Optional self-conditioning input forwarded to the model.
            beam_width: Number of candidates to return; clamped to the
                vocabulary size ``V``.

        Returns:
            A list of ``(tokens, score)`` pairs, best first. Candidate ``k``
            takes the k-th most probable token at every position; ``score``
            is a 0-dim tensor holding the sum of ``log(p + 1e-9)`` over all
            positions and all batch rows.

        Raises:
            ValueError: If ``self.temperature`` is not positive.
        """
        if self.temperature <= 0:
            raise ValueError("temperature must be positive")

        with torch.no_grad():
            logits, _hidden = model(condition, x_t, t, self_cond)

            logits = logits / self.temperature
            logits = self.apply_repetition_penalty(logits, x_t)
            logits = self.apply_diversity_penalty(logits)

            probs = F.softmax(logits, dim=-1)

            # torch.topk raises when k exceeds the size of the reduced axis.
            num_candidates = min(beam_width, probs.shape[-1])
            topk_probs, topk_ids = torch.topk(probs, num_candidates, dim=-1)

            # NOTE: the score sums over the batch axis too, so candidates are
            # ranked jointly for the whole batch rather than per sample.
            return [
                (
                    topk_ids[:, :, rank],
                    torch.log(topk_probs[:, :, rank] + 1e-9).sum(),
                )
                for rank in range(num_candidates)
            ]

    # ------------------------------------------------
    # beam reverse diffusion
    # ------------------------------------------------

    def generate_beam(self, model, condition, beam_width=3, num_steps=None):
        """Generate tokens by beam search over the reverse diffusion steps.

        Args:
            model: Callable denoising model (see ``p_sample_step``); it must
                also expose a ``mask_token_id`` attribute.
            condition: Integer tensor of shape ``(L,)`` or ``(B, L)``. A 1-D
                input is treated as a batch of one.
            beam_width: Number of beams kept after every step.
            num_steps: Number of reverse steps; defaults to
                ``self.scheduler.num_timesteps``.

        Returns:
            The ``(B, L)`` token tensor of the best-scoring beam.

        Raises:
            ValueError: If ``beam_width`` is smaller than 1.
        """
        if beam_width < 1:
            raise ValueError("beam_width must be at least 1")

        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device

        if condition.dim() == 1:
            condition = condition.unsqueeze(0)

        B, L = condition.shape

        # ------------------------------------------------
        # BETTER LATENT INITIALIZATION
        # ------------------------------------------------
        # Start from the condition with roughly half of the positions
        # (chosen uniformly at random) replaced by the mask token.
        x_init = condition.clone()
        mask = torch.rand_like(x_init.float()) < 0.5
        x_init[mask] = model.mask_token_id

        # Each beam is (tokens, cumulative un-normalised log-score).
        beams = [(x_init, 0.0)]
        self_cond = None

        for step in reversed(range(num_steps)):
            # The timestep tensor is identical for every beam of this step.
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)

            expanded = []
            for x_t, raw_score in beams:
                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition, self_cond, beam_width
                )
                for tokens, step_score in candidates:
                    total = raw_score + step_score
                    length_norm = tokens.shape[1] ** self.length_penalty
                    # Normalise only the ranking key. Storing the normalised
                    # value as the running score would divide earlier steps
                    # by the length norm again at every later step.
                    expanded.append((tokens, total, total / length_norm))

            expanded.sort(key=lambda beam: beam[2], reverse=True)
            beams = [(tokens, total) for tokens, total, _ in expanded[:beam_width]]

            # self conditioning: feed the best beam's tokens to the next step
            self_cond = beams[0][0]

        return beams[0][0]