"""
reverse_process.py — Final Correct Version
=============================================

KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
in inference.py, which is what produced BERTScore 0.75 at validation.

CRITICAL BUG IN PREVIOUS VERSION:
    We passed inference_mode=True to the model, but the model was NEVER
    called with inference_mode=True during training or validation.
    run_inference() (the validated path) does:
        model(input_ids, x0_est, t, x0_hint=hint)
    → inference_mode defaults to False.
    With inference_mode=True the model does two things differently:
        1. tgt_pad_mask = None   (training used tgt_pad_mask = tgt==PAD)
        2. Skips q_sample at t=0 (training always called q_sample)
    The model was never trained to handle these conditions → garbage output.
    Fix: do NOT pass inference_mode. Let it default to False, exactly
    as run_inference() did.

BUGS FIXED (vs original reverse_process.py)
--------------------------------------------
BUG 1  generate_beam() used for D3PM → all-ṁ repetition.
       Use generate() (iterative refinement) from app1.py instead.
BUG 2  apply_diversity_penalty used logits.var() → noise injection.
       Fixed to logits - penalty * logits.mean(dim=1) → global suppression.
BUG 3  x0_hint (self-conditioning) never passed to model.
       Fixed: generate() passes x0_hint=hint every step.
BUG 4  params not forwarded from generate_beam() to p_sample_step().
       Fixed in generate_beam() (kept for reference, not for production use).
"""
import torch
import torch.nn.functional as F
class ReverseDiffusion:
    """Reverse (denoising) process for a D3PM text-diffusion model.

    generate() is the validated production sampler (mirrors run_inference()
    in inference.py). p_sample_step() and generate_beam() are retained for
    experimental reference only — see their docstrings for why they must
    not be used in production.
    """

    def __init__(self, scheduler):
        # Scheduler must expose `num_timesteps`; nothing else of it is read here.
        self.scheduler = scheduler
        # Attribute-style defaults for backward compat with any code
        # that sets reverse_diffusion.temperature = 0.9 etc.
        # generate() prefers explicit kwargs and falls back to these.
        self.temperature = 0.75          # softmax temperature (lower = sharper)
        self.repetition_penalty = 1.15   # > 1.0 down-weights already-emitted tokens
        self.diversity_penalty = 0.0     # 0.0 disables global-mean suppression
        self.top_k = 50                  # 0 disables top-k filtering

    # ------------------------------------------------------------------ #
    #  generate — CORRECT D3PM iterative refinement                       #
    #  Exact equivalent of run_inference() in inference.py                #
    # ------------------------------------------------------------------ #
    def generate(
        self,
        model,
        condition,
        num_steps = None,
        temperature = None,
        top_k = None,
        repetition_penalty = None,
        diversity_penalty = None,
    ):
        """
        D3PM iterative refinement — identical to run_inference() in
        inference.py, which is the validated path (BERTScore 0.75).

        Parameters
        ----------
        model : trained denoiser; must expose `mask_token_id` and be callable
            as model(src, x0_est, t, x0_hint=...) -> (logits, _).
        condition : LongTensor (L,) or (B, L) of source token ids.
        num_steps, temperature, top_k, repetition_penalty, diversity_penalty :
            per-call overrides; each falls back to the matching attribute
            set in __init__ when None.

        Returns
        -------
        LongTensor (B, L) of generated token ids.

        Algorithm:
            x0_est = all [MASK]
            for t = T-1 down to 0:
                logits = model(src, x0_est, t, x0_hint=hint)
                    → inference_mode NOT passed (defaults to False)
                    → this exactly matches training/validation
                apply penalties, temperature, top_k
                if t > 0:  x0_est = multinomial(softmax(logits))  → stochastic
                if t == 0: x0_est = argmax(softmax(logits))       → deterministic
                hint = x0_est
        """
        # Resolve: explicit kwarg > object attribute
        temperature = temperature if temperature is not None else self.temperature
        top_k = top_k if top_k is not None else self.top_k
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps
        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape
        T = self.scheduler.num_timesteps
        # Stride the timestep grid so roughly `num_steps` model calls are made;
        # always end on t=0 so the final deterministic argmax step runs.
        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)
        mask_id = model.mask_token_id
        # Start from the fully-masked sequence.
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None  # self-conditioning hint: previous step's estimate
        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)
                # —— CRITICAL: do NOT pass inference_mode ——————————————————
                # inference_mode defaults to False inside SanskritModel /
                # D3PMCrossAttention. This matches run_inference() exactly.
                # Passing inference_mode=True changes tgt_pad_mask and
                # q_sample behaviour — the model was never trained for that.
                logits, _ = model(condition, x0_est, t, x0_hint=hint)
                # Repetition penalty
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, x0_est, repetition_penalty
                    )
                # Diversity penalty (correct: global mean suppression)
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)
                logits = logits / max(temperature, 1e-5)  # 1e-5 guards div-by-zero
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)
                probs = F.softmax(logits, dim=-1)
                # Stochastic at every step except the last (argmax at t=0)
                if is_last:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)
                hint = x0_est
        return x0_est  # (B, L)

    # ------------------------------------------------------------------ #
    #  p_sample_step — used by generate_beam (not for production)         #
    # ------------------------------------------------------------------ #
    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width = 3,
        temperature = 1.0,
        repetition_penalty = 1.2,
        diversity_penalty = 0.3,
        x0_hint = None,
    ):
        """Run one denoising step and return `beam_width` candidate
        (tokens, score) pairs.

        Each candidate k takes the k-th ranked token at EVERY position
        simultaneously; `score` is the summed log-probability of those
        tokens. This uniform-rank construction is what makes
        generate_beam() degenerate — kept only for reference.

        NOTE(review): temperature is applied BEFORE the penalties here,
        whereas generate() applies penalties first — confirm whether the
        difference is intentional before reusing this path.
        """
        with torch.no_grad():
            # Normalise all inputs to batched form.
            if x_t.dim() == 1: x_t = x_t.unsqueeze(0)
            if condition.dim() == 1: condition = condition.unsqueeze(0)
            if t.dim() == 0: t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])
            # No inference_mode → matches training convention
            logits, _ = model(condition, x_t, t, x0_hint=x0_hint)
            logits = logits / max(temperature, 1e-5)
            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = apply_diversity_penalty(logits, diversity_penalty)
            probs = F.softmax(logits, dim=-1)
            B, L, V = probs.shape
            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
            candidates = []
            for k in range(beam_width):
                # k-th best token at every position; one scalar score per candidate.
                next_tokens = topk_ids[:, :, k]
                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((next_tokens, score))
            return candidates

    # ------------------------------------------------------------------ #
    #  generate_beam — kept for reference; NOT the correct D3PM method    #
    # ------------------------------------------------------------------ #
    def generate_beam(
        self,
        model,
        condition,
        beam_width = 3,
        num_steps = None,
        temperature = None,
        repetition_penalty = None,
        diversity_penalty = None,
    ):
        """
        WARNING: do NOT call this from app1.py for D3PM generation.
        generate_beam() forces every position to the same top-k token
        → all-ṁ / all-rud repetition. Use generate() instead.
        Kept only for experimental reference.
        """
        # Resolve: explicit kwarg > object attribute (BUG 4 fix: these
        # are now actually forwarded to p_sample_step below).
        temperature = temperature if temperature is not None else self.temperature
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps
        device = condition.device
        if condition.dim() == 1: condition = condition.unsqueeze(0)
        B, L = condition.shape
        # All-[MASK] start, as in generate().
        x_init = torch.full((B, L), fill_value=model.mask_token_id,
                            dtype=torch.long, device=device)
        beams = [(x_init, 0.0)]
        best_hint = None  # self-conditioning hint taken from the best beam
        for step in reversed(range(num_steps)):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            new_beams = []
            for x_t, score in beams:
                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition,
                    beam_width = beam_width,
                    temperature = temperature,
                    repetition_penalty = repetition_penalty,
                    diversity_penalty = diversity_penalty,
                    x0_hint = best_hint,
                )
                for tokens, new_score in candidates:
                    new_beams.append((tokens, score + new_score.item()))
            # Keep the `beam_width` highest cumulative-score beams.
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_width]
            best_hint = beams[0][0]
        return beams[0][0]  # (B, L)
# ββ Penalty helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """Down-weight tokens already present in the sequence.

    Every vocabulary id that appears in ``prev_tokens[b]`` (and is > 4,
    i.e. not one of the first five special-token ids) has its logits
    divided by ``penalty`` at ALL positions of batch element b.

    Parameters
    ----------
    logits : FloatTensor (B, L, V) — raw model scores.
    prev_tokens : LongTensor (B, L) — token ids already in the sequence.
    penalty : float — divisor; > 1.0 suppresses repeats, 1.0 is a no-op.

    Returns
    -------
    FloatTensor (B, L, V) of penalised scores. Unlike the previous
    version, a NEW tensor is returned and the input is left unmodified.

    NOTE(review): dividing by ``penalty`` BOOSTS tokens whose logit is
    negative; standard implementations (CTRL paper, HF transformers'
    RepetitionPenaltyLogitsProcessor) multiply negative logits by the
    penalty instead. Division is kept here to match the validated
    inference path — confirm before changing.
    """
    batch_size, _, vocab_size = logits.shape
    # Vectorised replacement for the original O(B*L*V) Python loops:
    # mark which vocab ids occur in each batch row...
    present = torch.zeros(
        batch_size, vocab_size, dtype=torch.bool, device=logits.device
    )
    present.scatter_(1, prev_tokens, True)
    present[:, :5] = False  # ids 0-4 are special tokens; never penalise
    # ...then divide those ids' logits at every sequence position.
    mask = present.unsqueeze(1).expand_as(logits)
    return torch.where(mask, logits / penalty, logits)
def apply_diversity_penalty(logits, penalty=0.3):
    """Suppress globally dominant tokens.

    Subtracts ``penalty`` times the mean logit taken over the sequence
    dimension, so tokens that score highly at every position are pushed
    down uniformly (global suppression rather than noise injection).
    """
    sequence_mean = torch.mean(logits, dim=1, keepdim=True)  # (B, 1, V)
    return logits.sub(penalty * sequence_mean)
def top_k_filter(logits, k):
    """Mask everything outside the k largest logits per position with -inf.

    Values tied with the k-th largest logit survive, so more than k
    entries may remain finite. If k covers the whole vocabulary the
    input is returned untouched.
    """
    batch, seq_len, vocab = logits.shape
    if k >= vocab:
        return logits
    kth_best = torch.topk(logits, k, dim=-1).values[..., -1:]  # (B, L, 1)
    keep = logits >= kth_best
    return torch.where(keep, logits, torch.full_like(logits, float("-inf")))
def batch_multinomial(probs):
    """Sample one token id per position from a (B, L, V) distribution.

    Positions are flattened so a single ``torch.multinomial`` call draws
    for the whole batch; a tiny epsilon plus renormalisation guards
    against all-zero rows.
    """
    batch, seq_len, vocab = probs.shape
    flat_probs = probs.view(batch * seq_len, vocab) + 1e-9
    flat_probs = flat_probs / flat_probs.sum(dim=-1, keepdim=True)
    samples = torch.multinomial(flat_probs, 1)
    return samples.squeeze(-1).view(batch, seq_len)