"""
reverse_process.py — Fixed
===========================
Two bugs fixed from the original:

BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to the model.
    The model does q_sample(tgt, t) internally — so x_t got double-noised.
    Fix: pass x0_estimate (current clean guess) as tgt. The model noises it
    correctly.

BUG 2: apply_diversity_penalty used logits.var(dim=-1) — this adds the
    variance of each position's own distribution back to itself, which is
    mathematically meaningless and just injects noise.
    Fix: penalize tokens that score uniformly high across ALL positions
    (globally common tokens). This genuinely promotes diversity.
"""
import torch
import torch.nn.functional as F
class ReverseDiffusion:
    """Reverse (denoising) sampling loops for a discrete diffusion model.

    Three entry points are provided:

    * ``p_sample_step`` - one reverse step that returns ranked candidates.
    * ``generate_beam`` - beam-style search built on ``p_sample_step``.
    * ``generate``      - iterative refinement with sampling.

    The model is called as ``model(condition, tokens, t)`` and must expose
    ``mask_token_id``; the scheduler must expose ``num_timesteps``.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler

    @staticmethod
    def _extract_logits(outputs):
        """Return the logits from a model output that may be a tuple/list."""
        if isinstance(outputs, (tuple, list)):
            return outputs[0]
        return outputs

    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """Run a single reverse step with temperature and penalties.

        Args:
            model: callable invoked as ``model(condition, x_t, t)``; it may
                return the logits alone or a tuple whose first item is the
                logits.
            x_t: current token ids; a 1-D tensor is promoted to a batch of 1.
            t: timestep tensor; a 0-dim tensor is promoted and expanded to
                the batch size of ``x_t``.
            condition: conditioning token ids; 1-D input is promoted too.
            beam_width: number of candidates to return (capped at the size
                of the last logits dimension).
            temperature: softmax temperature, clamped to at least ``1e-5``.
            repetition_penalty: forwarded to ``apply_repetition_penalty``
                unless it equals ``1.0``.
            diversity_penalty: forwarded to ``apply_diversity_penalty`` when
                greater than zero.

        Returns:
            A list of ``(tokens, score)`` tuples. Candidate ``k`` takes the
            k-th most probable token at every position; ``score`` is a 0-dim
            tensor holding the sum of ``log(prob + 1e-9)`` over all batch
            rows and positions.
        """
        with torch.no_grad():
            # ---- Shape safety ----
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)
            if condition.dim() == 1:
                condition = condition.unsqueeze(0)
            if t.dim() == 0:
                t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # ---- Model forward ----
            logits = self._extract_logits(model(condition, x_t, t))

            # ---- Temperature scaling (same clamp as generate()) ----
            logits = logits / max(temperature, 1e-5)

            # ---- Repetition penalty ----
            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(
                    logits, x_t, repetition_penalty
                )

            # ---- Diversity penalty ----
            if diversity_penalty > 0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)

            # ---- Top-k beam expansion ----
            # torch.topk raises if k exceeds the dimension size, so cap it.
            width = min(beam_width, probs.shape[-1])
            topk_probs, topk_ids = torch.topk(probs, width, dim=-1)
            log_probs = torch.log(topk_probs + 1e-9)

            return [
                (topk_ids[:, :, k], log_probs[:, :, k].sum())
                for k in range(width)
            ]

    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """Beam-search reverse diffusion with temperature.

        Starts from an all-``mask_token_id`` sequence shaped like
        ``condition`` and walks the timesteps ``num_steps - 1 .. 0``. At each
        step every beam is expanded with ``p_sample_step`` and only the
        ``beam_width`` highest-scoring candidates are kept.

        Args:
            model: model exposing ``mask_token_id`` and ``eval()``.
            condition: conditioning token ids, 1-D or ``(B, L)``.
            beam_width: beams kept per step and candidates per expansion.
            num_steps: number of reverse steps; defaults to
                ``scheduler.num_timesteps``.
            temperature, repetition_penalty, diversity_penalty: forwarded to
                ``p_sample_step``.

        Returns:
            The token tensor of the highest-scoring beam.
        """
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # Better initialization: start from MASK.
        x_init = torch.full(
            (B, L),
            fill_value=model.mask_token_id,
            dtype=torch.long,
            device=device
        )

        # Inference mode, consistent with generate().
        model.eval()

        beams = [(x_init, 0.0)]
        for step in reversed(range(num_steps)):
            # The timestep tensor is identical for every beam at this step.
            t_tensor = torch.full(
                (B,), step, dtype=torch.long, device=device
            )

            new_beams = []
            for x_t, score in beams:
                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    beam_width,
                    temperature,
                    repetition_penalty,
                    diversity_penalty
                )
                new_beams.extend(
                    (tokens, score + step_score)
                    for tokens, step_score in candidates
                )

            # ---- Keep top beams ----
            new_beams.sort(key=lambda beam: float(beam[1]), reverse=True)
            beams = new_beams[:beam_width]

        best_tokens, _ = beams[0]
        return best_tokens

    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=0.8,
        top_k=50,
        repetition_penalty=1.2,
        diversity_penalty=0.0,
    ):
        """Iteratively refine an estimate of the clean sequence (D3PM style).

        ``x0_est`` starts as all ``mask_token_id``. At every timestep the
        model is called with ``tgt=x0_est`` (per the original design notes
        the model is expected to noise ``tgt`` itself), and ``x0_est`` is
        replaced by tokens drawn from the predicted distribution. The final
        step takes the argmax instead of sampling.

        Args:
            model: model exposing ``mask_token_id``, ``eval()`` and
                ``forward``; if ``forward`` accepts ``x0_hint`` the previous
                estimate is passed through it.
            condition: conditioning token ids, 1-D or ``(B, L)``.
            num_steps: desired number of refinement steps; defaults to
                ``scheduler.num_timesteps``. Timesteps are strided with
                ``max(1, T // num_steps)`` and always end at 0.
            temperature: softmax temperature, clamped to at least ``1e-5``.
            top_k: keep only the ``top_k`` logits per position when > 0.
            repetition_penalty: applied unless equal to ``1.0``.
            diversity_penalty: applied when greater than ``0.0``; it reduces
                tokens that are globally dominant across positions.

        Returns:
            Long tensor of shape ``(B, L)`` with the final token estimate.

        Raises:
            ValueError: if ``num_steps`` is not positive.
        """
        import inspect

        T = self.scheduler.num_timesteps
        if num_steps is None:
            num_steps = T
        if num_steps <= 0:
            raise ValueError("num_steps must be a positive integer")

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)
        last_idx = len(timesteps) - 1

        # Start: know nothing -> all MASK is our initial clean estimate.
        x0_est = torch.full(
            (B, L), model.mask_token_id, dtype=torch.long, device=device
        )
        hint = None

        # Loop invariants: inspect the forward signature and clamp the
        # temperature once instead of on every step.
        accepts_hint = 'x0_hint' in inspect.signature(model.forward).parameters
        safe_temperature = max(temperature, 1e-5)

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)

                # KEY: pass x0_est as tgt; the model noises it internally.
                if accepts_hint:
                    outputs = model(condition, x0_est, t, x0_hint=hint)
                else:
                    outputs = model(condition, x0_est, t)
                logits = self._extract_logits(outputs)

                # Repetition penalty: down-weight tokens already in sequence.
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, x0_est, repetition_penalty
                    )

                # Diversity penalty: reduce globally dominant tokens.
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                # Temperature + top-k.
                logits = logits / safe_temperature
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                if step_idx == last_idx:
                    # Final step is deterministic.
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)
                hint = x0_est

        return x0_est
# ── Penalty functions ─────────────────────────────────────────────────
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2,
                             special_token_max_id=4):
    """Down-weight tokens that already appear in the current sequence.

    Helps prevent loops in which the same token is emitted over and over.

    For every batch row, each vocabulary id that occurs in ``prev_tokens``
    is penalized at *all* positions. Positive logits are divided by
    ``penalty`` and negative logits are multiplied by it, so the token always
    becomes less likely (plain division would make a negative logit *more*
    likely).

        penalty=1.0 -> no effect
        penalty=1.2 -> mild suppression of repeated tokens
        penalty=2.0 -> strong suppression

    Args:
        logits: float tensor of shape ``(B, L, V)``.
        prev_tokens: integer tensor of shape ``(B, L_prev)`` with token ids.
        penalty: positive scaling factor.
        special_token_max_id: ids less than or equal to this value are
            treated as special tokens and never penalized.

    Returns:
        A new tensor with the penalty applied; ``logits`` itself is not
        modified (it is returned unchanged when ``penalty == 1.0``).

    Raises:
        ValueError: if ``penalty`` is not positive.
    """
    if penalty <= 0:
        raise ValueError("penalty must be positive")
    if penalty == 1.0:
        return logits

    B, L, V = logits.shape
    token_ids = prev_tokens.to(device=logits.device, dtype=torch.long)

    # Ids that must not be penalized (special, negative or outside the
    # vocabulary) are routed to a spare column V that is sliced off below.
    eligible = (
        (token_ids > special_token_max_id)
        & (token_ids >= 0)
        & (token_ids < V)
    )
    columns = torch.where(eligible, token_ids, torch.full_like(token_ids, V))
    rows = torch.arange(
        columns.shape[0], device=logits.device
    ).unsqueeze(1).expand_as(columns)

    seen = torch.zeros(B, V + 1, dtype=torch.bool, device=logits.device)
    seen[rows, columns] = True
    seen = seen[:, :V].unsqueeze(1)  # [B, 1, V], broadcast over positions

    penalized = torch.where(logits > 0, logits / penalty, logits * penalty)
    return torch.where(seen, penalized, logits)
def apply_diversity_penalty(logits, penalty=0.5):
    """Suppress tokens that score highly at every sequence position.

    The logit of each vocabulary entry is averaged over dimension 1 (the
    position axis of a ``[B, L, V]`` tensor) and ``penalty`` times that
    average is subtracted from every position. A token whose logit is high
    everywhere therefore loses the most. Note that the average is taken over
    raw logits, not over probabilities.

        penalty=0.0 -> no diversity enforcement
        penalty=0.5 -> moderate diversity
        penalty=1.0 -> strong diversity (may hurt coherence)

    Returns a new tensor with the same shape as ``logits``.
    """
    # Per-token average over positions, kept as [B, 1, V] for broadcasting.
    position_average = logits.mean(dim=1, keepdim=True)
    offset = penalty * position_average
    return logits - offset
def top_k_filter(logits, k):
    """Keep the ``k`` largest logits along the last dimension.

    Every entry strictly below the k-th largest value of its row is set to
    ``-inf``; entries tied with the k-th value are kept.

    Args:
        logits: float tensor whose last dimension is the vocabulary.
        k: number of entries to keep. ``k <= 0`` or ``k`` at least the
            vocabulary size disables filtering.

    Returns:
        The filtered tensor, or ``logits`` itself when filtering is disabled.
    """
    vocab_size = logits.size(-1)
    if k <= 0 or k >= vocab_size:
        return logits
    # Slice with -1: to keep the last dim so the threshold broadcasts.
    threshold = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < threshold, float('-inf'))
def batch_multinomial(probs):
    """Draw one sample per position from a batch of categorical distributions.

    Args:
        probs: non-negative float tensor whose last dimension is the
            vocabulary, e.g. ``(B, L, V)``. Rows do not need to be normalized.

    Returns:
        Long tensor of sampled indices shaped like ``probs`` without its last
        dimension.

    Entries with exactly zero probability (for example tokens removed by
    top-k filtering) are never drawn. A row with no probability mass at all
    falls back to a uniform draw, because ``torch.multinomial`` rejects such
    rows.
    """
    vocab_size = probs.shape[-1]
    # reshape (not view) so non-contiguous inputs are accepted.
    flat = probs.reshape(-1, vocab_size)
    no_mass = flat.sum(dim=-1, keepdim=True) <= 0
    flat = torch.where(no_mass, torch.ones_like(flat), flat)
    draws = torch.multinomial(flat, 1).squeeze(-1)
    return draws.view(probs.shape[:-1])