Spaces:
Running
Running
File size: 3,931 Bytes
27f26fd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import torch
import torch.nn.functional as F
class ReverseDiffusion:
    """
    Beam-search reverse diffusion over discrete token sequences.

    Features:
    - Beam search over per-position top-k candidates
    - Self conditioning (the current best beam is fed back to the model)
    - Temperature scaling of logits (candidates are chosen deterministically
      with top-k; no random sampling is performed)
    - Repetition penalty (CTRL-style: lowers the logits of tokens already
      present in the current sequence)
    - Diversity penalty (see ``apply_diversity_penalty`` for a caveat)
    """

    def __init__(self, scheduler):
        # Only ``scheduler.num_timesteps`` is read (by generate_beam when
        # num_steps is None).
        self.scheduler = scheduler
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        # Exponent applied to the sequence length when normalising beam scores.
        self.length_penalty = 1.0

    # ------------------------------------------------
    # penalties
    # ------------------------------------------------
    def apply_repetition_penalty(self, logits, tokens):
        """Penalise the logits of tokens that already occur in ``tokens``.

        Follows the CTRL / Hugging Face convention: positive logits are
        divided by ``self.repetition_penalty`` and negative logits are
        multiplied by it, so a penalty > 1 always lowers a repeated token's
        score. (Dividing unconditionally would *raise* negative logits and
        make repeats more likely.)

        Args:
            logits: tensor of shape (B, L, V); updated in place.
            tokens: 2-D integer tensor; row ``b`` (for b < B) holds the token
                ids already present in sequence ``b``. Ids must index into V.

        Returns:
            ``logits`` (the same tensor object, updated in place).
        """
        penalty = self.repetition_penalty
        if penalty == 1.0:
            # Dividing/multiplying by 1 is a no-op; skip the work.
            return logits
        B, _, V = logits.shape
        device = logits.device
        ids = tokens[:B].reshape(B, -1).long().to(device)

        # seen[b, v] is True when token v appears anywhere in row b.
        seen = torch.zeros(B, V, dtype=torch.bool, device=device)
        rows = torch.arange(B, device=device).unsqueeze(1)
        seen[rows, ids] = True

        penalized = torch.where(logits < 0, logits * penalty, logits / penalty)
        # Unsqueeze to (B, 1, V) so the mask broadcasts over every position.
        logits.copy_(torch.where(seen.unsqueeze(1), penalized, logits))
        return logits

    def apply_diversity_penalty(self, logits):
        """Add ``diversity_penalty`` times the per-position logit variance.

        NOTE(review): the variance is taken over the last (vocabulary) axis
        with keepdim=True, so the same constant is added to every logit at a
        given (batch, position). Softmax / log-softmax are shift-invariant, so
        this does not change the resulting probabilities or the top-k choice.
        Behaviour kept as-is; revisit if a real diversity penalty is intended.
        """
        if self.diversity_penalty == 0:
            return logits
        logits_var = logits.var(dim=-1, keepdim=True)
        return logits + self.diversity_penalty * logits_var

    # ------------------------------------------------
    # single reverse step
    # ------------------------------------------------
    def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
        """Run one reverse step and propose candidate sequences.

        The model is called as ``model(condition, x_t, t, self_cond)`` and
        must return ``(logits, hidden)`` with logits of shape (B, L, V);
        ``hidden`` is ignored. Candidate ``k`` takes the k-th most likely
        token at every position simultaneously.

        Returns:
            list of ``(tokens, score)`` pairs, best first. ``tokens`` has
            shape (B, L). ``score`` is a 0-dim tensor with the summed
            log-probability of the chosen tokens over all batch rows and
            positions. At most ``min(beam_width, V)`` candidates are returned.

        Raises:
            ValueError: if ``self.temperature`` is not positive.
        """
        if self.temperature <= 0:
            raise ValueError("temperature must be > 0")
        with torch.no_grad():
            logits, _hidden = model(condition, x_t, t, self_cond)
            # Dividing produces a new tensor, so the in-place repetition
            # penalty below never mutates the model's own output.
            logits = logits / self.temperature
            logits = self.apply_repetition_penalty(logits, x_t)
            logits = self.apply_diversity_penalty(logits)
            # log_softmax is numerically stable and never takes log(0).
            log_probs = F.log_softmax(logits, dim=-1)
            k = min(beam_width, log_probs.shape[-1])
            topk_log_probs, topk_ids = torch.topk(log_probs, k, dim=-1)
        return [
            (topk_ids[:, :, i], topk_log_probs[:, :, i].sum())
            for i in range(k)
        ]

    # ------------------------------------------------
    # beam reverse diffusion
    # ------------------------------------------------
    def generate_beam(self, model, condition, beam_width=3, num_steps=None):
        """Generate a token sequence by beam-searched reverse diffusion.

        Args:
            model: callable ``model(condition, x_t, t, self_cond)`` returning
                ``(logits, hidden)``; it must also expose ``mask_token_id``.
            condition: integer token tensor of shape (L,) or (B, L). It also
                seeds the initial latent: each position is independently
                replaced by ``model.mask_token_id`` with probability 0.5.
            beam_width: number of beams kept (and candidates drawn per beam)
                at every step.
            num_steps: number of reverse steps; defaults to
                ``self.scheduler.num_timesteps``. Timesteps run from
                ``num_steps - 1`` down to 0.

        Returns:
            tokens of the highest-scoring beam, shape (B, L).

        Raises:
            ValueError: if ``beam_width`` < 1.
        """
        if beam_width < 1:
            raise ValueError("beam_width must be >= 1")
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps
        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # ------------------------------------------------
        # BETTER LATENT INITIALIZATION
        # Start from the condition with ~50% of positions masked.
        # ------------------------------------------------
        x_init = condition.clone()
        mask = torch.rand_like(x_init.float()) < 0.5
        x_init[mask] = model.mask_token_id

        # Each beam is (tokens, raw cumulative log-probability). Length
        # normalisation is applied only to the ranking key; storing the
        # normalised value would re-divide earlier steps' scores at every
        # step and exponentially discount the beam history.
        beams = [(x_init, 0.0)]
        self_cond = None
        for step in reversed(range(num_steps)):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            scored = []
            for x_t, score in beams:
                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition, self_cond, beam_width
                )
                for tokens, new_score in candidates:
                    total = score + float(new_score)
                    length_norm = tokens.shape[1] ** self.length_penalty
                    scored.append((total / length_norm, tokens, total))
            # list.sort is stable, so ties keep generation order.
            scored.sort(key=lambda item: item[0], reverse=True)
            beams = [(tokens, total) for _, tokens, total in scored[:beam_width]]
            # Self conditioning: feed the current best beam back to the model.
            self_cond = beams[0][0]

        return beams[0][0]