# DevaFlow / diffusion/reverse_process.py
# (uploaded by bhsinghgrid via the upload-large-folder tool, commit 7d6a683)
"""
reverse_process.py โ€” Fixed
===========================
Two bugs fixed from the original:
BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to model.
The model does q_sample(tgt, t) internally โ€” so x_t got double-noised.
Fix: pass x0_estimate (current clean guess) as tgt. Model noises it correctly.
BUG 2: apply_diversity_penalty used logits.var(dim=-1) โ€” this adds the
variance of each position's own distribution back to itself, which is
mathematically meaningless and just injects noise.
Fix: penalize tokens that are uniformly high-probability across ALL positions
(global common tokens). This genuinely promotes diversity.
"""
import torch
import torch.nn.functional as F
class ReverseDiffusion:
    """Reverse (denoising) process for a mask-based discrete diffusion model.

    Two decoders are provided:

    * ``generate`` -- iterative refinement: each step re-predicts the clean
      sequence from the previous clean estimate and samples from it.
    * ``generate_beam`` -- a simple beam search over per-step predictions.

    ``scheduler`` must expose ``num_timesteps``. ``model`` must expose
    ``mask_token_id`` and ``eval()``, and be callable as
    ``model(condition, tgt, t)``, returning logits or a tuple whose first
    element is the logits.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler

    @staticmethod
    def _timestep_schedule(num_timesteps, num_steps=None):
        """Return the descending timesteps to visit, always ending at 0.

        ``num_steps`` evenly subsamples ``[0, num_timesteps)``; ``None`` visits
        every timestep. Asking for more steps than exist visits every
        timestep once.

        Raises:
            ValueError: if ``num_timesteps`` or ``num_steps`` is below 1.
        """
        if num_timesteps < 1:
            raise ValueError(f"num_timesteps must be >= 1, got {num_timesteps}")
        if num_steps is None:
            num_steps = num_timesteps
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        step_size = max(1, num_timesteps // num_steps)
        timesteps = list(range(num_timesteps - 1, -1, -step_size))
        # Always finish with a fully-denoising t=0 step.
        if timesteps[-1] != 0:
            timesteps.append(0)
        return timesteps

    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3,
    ):
        """Run one reverse step and propose up to ``beam_width`` candidates.

        Args:
            model: callable as ``model(condition, x_t, t)``.
            x_t: current token sequence ``[B, L]`` (``[L]`` is promoted to a
                batch of one).
            t: timestep(s); a scalar or single-element tensor is broadcast to
                the batch size.
            condition: source tokens ``[B, L_src]`` (or ``[L_src]``).
            beam_width: number of candidates (capped at the vocabulary size).
            temperature: logits are divided by ``max(temperature, 1e-5)``.
            repetition_penalty: forwarded to ``apply_repetition_penalty``;
                1.0 disables it.
            diversity_penalty: forwarded to ``apply_diversity_penalty``;
                values <= 0 disable it.

        Returns:
            A list of ``(tokens, score)`` pairs. Candidate ``k`` takes the
            ``k``-th most likely token at every position. ``score`` is the sum
            of ``log(p + 1e-9)`` over all positions *and* all batch rows, so a
            beam is one hypothesis for the whole batch.
        """
        with torch.no_grad():
            # ---- Shape safety ----
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)
            if condition.dim() == 1:
                condition = condition.unsqueeze(0)
            if t.dim() == 0:
                t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # Accept both `logits` and `(logits, ...)` return styles, as
            # generate() does.
            outputs = model(condition, x_t, t)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs

            # Clamped like generate() so temperature=0 cannot divide by zero.
            logits = logits / max(temperature, 1e-5)
            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)
            # torch.topk fails if asked for more entries than the vocabulary has.
            num_candidates = min(beam_width, probs.shape[-1])
            topk_probs, topk_ids = torch.topk(probs, num_candidates, dim=-1)

            return [
                (topk_ids[..., k], torch.log(topk_probs[..., k] + 1e-9).sum())
                for k in range(num_candidates)
            ]

    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3,
    ):
        """Decode with a simple beam search over reverse-diffusion steps.

        Decoding starts from an all-``[MASK]`` sequence. It visits the same
        timestep schedule as ``generate``: ``num_steps`` evenly spread over
        ``[0, scheduler.num_timesteps)``, with ``None`` meaning every
        timestep. At each step every beam is expanded with ``p_sample_step``,
        and the ``beam_width`` highest cumulative scores are kept.

        Beam tokens are the model's own predicted tokens, so they are passed
        to the model as ``tgt`` in the same role as ``generate``'s clean
        estimate.

        Returns:
            The ``[B, L]`` token tensor of the best-scoring beam.
        """
        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        batch_size, seq_len = condition.shape

        timesteps = self._timestep_schedule(self.scheduler.num_timesteps, num_steps)
        model.eval()  # deterministic scoring (no dropout), as in generate()

        # Start from all-MASK: the "know nothing" estimate.
        start = torch.full(
            (batch_size, seq_len),
            fill_value=model.mask_token_id,
            dtype=torch.long,
            device=device,
        )
        beams = [(start, 0.0)]

        for t_val in timesteps:
            # Same timestep for every beam at this step: build it once.
            t = torch.full((batch_size,), t_val, dtype=torch.long, device=device)
            expanded = []
            for tokens, score in beams:
                candidates = self.p_sample_step(
                    model,
                    tokens,
                    t,
                    condition,
                    beam_width,
                    temperature,
                    repetition_penalty,
                    diversity_penalty,
                )
                for cand_tokens, cand_score in candidates:
                    expanded.append((cand_tokens, score + cand_score))
            # Stable descending sort: ties keep their expansion order.
            expanded.sort(key=lambda beam: beam[1], reverse=True)
            beams = expanded[:beam_width]

        best_tokens, _ = beams[0]
        return best_tokens

    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=0.8,
        top_k=50,
        repetition_penalty=1.2,
        diversity_penalty=0.0,
    ):
        """Decode by D3PM-style iterative refinement.

        ``x0_est`` starts as all ``[MASK]``. At each timestep ``t`` the model is
        called as ``model(condition, x0_est, t)``; the current *clean* estimate
        is passed as ``tgt``. The model is expected to noise it internally
        (``q_sample``; confirm against the model's forward), so it is not
        pre-noised here. The logits are then penalised, temperature-scaled
        and top-k filtered. They are sampled to give the next ``x0_est``,
        except on the last step, which takes the argmax.

        If the model's ``forward`` accepts ``x0_hint``, the previous step's
        estimate is passed as ``x0_hint`` (``None`` on the first step).

        Args:
            num_steps: number of refinement steps, spread evenly over
                ``[0, scheduler.num_timesteps)``; ``None`` uses every timestep.
            temperature: divisor for the logits (clamped to >= 1e-5).
            top_k: keep only the ``top_k`` highest logits per position;
                values <= 0 disable filtering.
            repetition_penalty: 1.0 disables it.
            diversity_penalty: reduces tokens whose logit is high on average
                across all positions; 0.0 disables it.

        Returns:
            ``[B, L]`` tensor of token ids.
        """
        import inspect

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        batch_size, seq_len = condition.shape

        timesteps = self._timestep_schedule(self.scheduler.num_timesteps, num_steps)
        # Signature inspection is loop-invariant: do it once, not every step.
        forward = getattr(model, "forward", model)
        accepts_hint = "x0_hint" in inspect.signature(forward).parameters

        # Start: know nothing -> all MASK is the initial clean estimate.
        x0_est = torch.full(
            (batch_size, seq_len), model.mask_token_id, dtype=torch.long, device=device
        )
        hint = None

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((batch_size,), t_val, dtype=torch.long, device=device)
                if accepts_hint:
                    outputs = model(condition, x0_est, t, x0_hint=hint)
                else:
                    outputs = model(condition, x0_est, t)
                logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Down-weight tokens already present in the current estimate.
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)
                # Reduce tokens that dominate every position.
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                logits = logits / max(temperature, 1e-5)
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)
                probs = F.softmax(logits, dim=-1)

                is_last = step_idx == len(timesteps) - 1
                x0_est = torch.argmax(probs, dim=-1) if is_last else batch_multinomial(probs)
                hint = x0_est
        return x0_est
# ── Penalty, filtering and sampling helpers ──────────────────────────────────
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2, special_token_max_id=4):
    """Down-weight tokens that already appear in the current sequence.

    Discourages repetition loops such as "मनो मनो मनो".

    Uses the sign-aware rule from CTRL (as in Hugging Face's
    ``RepetitionPenaltyLogitsProcessor``): a positive logit is divided by
    ``penalty`` and a negative one is multiplied by it. A repeated token
    therefore always becomes *less* likely. Plain division would pull a
    negative logit toward zero and make the token *more* likely.

    Args:
        logits: ``[B, L, V]`` logits. Not modified in place.
        prev_tokens: ``[B, L']`` token ids already present in each sequence.
        penalty: 1.0 has no effect, 1.2 suppresses mildly, 2.0 strongly.
        special_token_max_id: ids ``<= special_token_max_id`` are never
            penalised (special tokens). The default of 4 matches the
            historical behaviour.

    Returns:
        A new ``[B, L, V]`` tensor.
    """
    batch_size, _, vocab_size = logits.shape
    ids = prev_tokens[:batch_size].to(device=logits.device, dtype=torch.long)
    eligible = ids > special_token_max_id
    # Count eligible occurrences per (row, token). Ineligible ids are routed to
    # index 0 with weight 0, so they never mark anything as present.
    counts = torch.zeros(batch_size, vocab_size, device=logits.device)
    counts.scatter_add_(
        1,
        torch.where(eligible, ids, torch.zeros_like(ids)),
        eligible.to(counts.dtype),
    )
    present = (counts > 0).unsqueeze(1)  # [B, 1, V], broadcast over positions

    penalized = torch.where(logits < 0, logits * penalty, logits / penalty)
    return torch.where(present, penalized, logits)
def apply_diversity_penalty(logits, penalty=0.5):
    """Suppress tokens that score highly at every position of the sequence.

    For each batch row, the per-token *mean logit* over all positions is
    computed, and ``penalty`` times it is subtracted from every position's
    logits. Softmax ignores per-position constant shifts, so only the relative
    size of these means matters. Tokens favoured everywhere lose probability
    mass, which encourages more varied output.

    Non-finite logits (e.g. ``-inf`` used to ban a token) are excluded from the
    mean. A single ``-inf`` would otherwise make the mean ``-inf``, and
    ``-inf - penalty * -inf`` would produce ``nan``. Such entries pass through
    as ``-inf``.

    Args:
        logits: ``[B, L, V]`` logits.
        penalty: 0.0 gives no diversity enforcement, 0.5 moderate, 1.0 strong
            (may hurt coherence).

    Returns:
        A new tensor of the same shape.
    """
    finite = torch.isfinite(logits)
    safe = torch.where(finite, logits, torch.zeros_like(logits))
    # Per-token count of finite entries across positions; >= 1 avoids 0/0.
    counts = finite.sum(dim=1, keepdim=True).clamp(min=1)
    global_mean = safe.sum(dim=1, keepdim=True) / counts  # [B, 1, V]
    return logits - penalty * global_mean
def top_k_filter(logits, k):
    """Keep only the ``k`` largest logits along the last dimension.

    Every other entry is set to ``-inf`` so it gets zero probability after
    softmax. Values tied with the k-th largest are all kept, so slightly more
    than ``k`` entries can survive. Works for tensors of any rank.

    Args:
        logits: tensor whose last dimension is the vocabulary.
        k: number of entries to keep. ``k <= 0`` or ``k >= vocab size``
            returns ``logits`` unchanged.

    Returns:
        A tensor of the same shape (``logits`` itself when no filtering applies).
    """
    if k <= 0 or k >= logits.size(-1):
        return logits
    # Slice with -1: to keep the trailing dim for broadcasting.
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, float("-inf"))
def batch_multinomial(probs):
    """Draw one token per position from ``[B, L, V]`` probabilities.

    Zero-probability tokens (for example ones removed by top-k filtering) can
    never be drawn; no epsilon is mixed into the distribution. A row whose
    total mass is zero or not finite falls back to a uniform draw instead of
    raising.

    Args:
        probs: ``[B, L, V]`` non-negative weights; need not be contiguous or
            normalised.

    Returns:
        ``[B, L]`` long tensor of sampled ids.
    """
    batch_size, seq_len, vocab_size = probs.shape
    # reshape (not view) so non-contiguous inputs such as transposes work.
    flat = probs.reshape(batch_size * seq_len, vocab_size)
    row_mass = flat.sum(dim=-1, keepdim=True)
    degenerate = ~(torch.isfinite(row_mass) & (row_mass > 0))
    flat = torch.where(degenerate, torch.ones_like(flat), flat)
    # torch.multinomial accepts unnormalised non-negative weights.
    return torch.multinomial(flat, 1).squeeze(-1).view(batch_size, seq_len)