DevaFlow-space / diffusion /reverse_process2.py
bhsinghgrid's picture
Upload Gradio Space app
1cc095d verified
"""
reverse_process.py β€” Final Correct Version
=============================================
KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
in inference.py, which is what produced BERTScore 0.75 at validation.
CRITICAL BUG IN PREVIOUS VERSION:
We passed inference_mode=True to the model, but the model was NEVER
called with inference_mode=True during training or validation.
run_inference() (the validated path) does:
model(input_ids, x0_est, t, x0_hint=hint)
β†’ inference_mode defaults to False.
With inference_mode=True the model does two things differently:
1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
2. Skips q_sample at t=0 (training always called q_sample)
The model was never trained to handle these conditions β†’ garbage output.
Fix: do NOT pass inference_mode. Let it default to False, exactly
as run_inference() did.
BUGS FIXED (vs original reverse_process.py)
--------------------------------------------
BUG 1 generate_beam() used for D3PM β†’ all-Ṛ repetition.
Use generate() (iterative refinement) from app1.py instead.
BUG 2 apply_diversity_penalty used logits.var() β†’ noise injection.
Fixed to logits - penalty * logits.mean(dim=1) β€” global suppression.
BUG 3 x0_hint (self-conditioning) never passed to model.
Fixed: generate() passes x0_hint=hint every step.
BUG 4 params not forwarded from generate_beam() to p_sample_step().
Fixed in generate_beam() (kept for reference, not for production use).
"""
import torch
import torch.nn.functional as F
class ReverseDiffusion:
def __init__(self, scheduler):
self.scheduler = scheduler
# Attribute-style defaults for backward compat with any code
# that sets reverse_diffusion.temperature = 0.9 etc.
# generate() prefers explicit kwargs and falls back to these.
self.temperature = 0.75
self.repetition_penalty = 1.15
self.diversity_penalty = 0.0
self.top_k = 50
# ------------------------------------------------------------------ #
# generate β€” CORRECT D3PM iterative refinement #
# Exact equivalent of run_inference() in inference.py #
# ------------------------------------------------------------------ #
def generate(
self,
model,
condition,
num_steps = None,
temperature = None,
top_k = None,
repetition_penalty = None,
diversity_penalty = None,
):
"""
D3PM iterative refinement β€” identical to run_inference() in inference.py,
which is the validated path (BERTScore 0.75).
Algorithm:
x0_est = all [MASK]
for t = T-1 down to 0:
logits = model(src, x0_est, t, x0_hint=hint)
↑ inference_mode NOT passed (defaults to False)
↑ this exactly matches training/validation
apply penalties, temperature, top_k
if t > 0: x0_est = multinomial(softmax(logits)) ← stochastic
if t = 0: x0_est = argmax(softmax(logits)) ← deterministic
hint = x0_est
"""
# Resolve: explicit kwarg > object attribute
temperature = temperature if temperature is not None else self.temperature
top_k = top_k if top_k is not None else self.top_k
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
if num_steps is None:
num_steps = self.scheduler.num_timesteps
device = condition.device
if condition.dim() == 1:
condition = condition.unsqueeze(0)
B, L = condition.shape
T = self.scheduler.num_timesteps
step_size = max(1, T // num_steps)
timesteps = list(range(T - 1, -1, -step_size))
if timesteps[-1] != 0:
timesteps.append(0)
mask_id = model.mask_token_id
x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
hint = None
model.eval()
with torch.no_grad():
for step_idx, t_val in enumerate(timesteps):
t = torch.full((B,), t_val, dtype=torch.long, device=device)
is_last = (step_idx == len(timesteps) - 1)
# ── CRITICAL: do NOT pass inference_mode ──────────────────
# inference_mode defaults to False inside SanskritModel /
# D3PMCrossAttention. This matches run_inference() exactly.
# Passing inference_mode=True changes tgt_pad_mask and
# q_sample behaviour β€” the model was never trained for that.
logits, _ = model(condition, x0_est, t, x0_hint=hint)
# Repetition penalty
if repetition_penalty != 1.0:
logits = apply_repetition_penalty(
logits, x0_est, repetition_penalty
)
# Diversity penalty (correct: global mean suppression)
if diversity_penalty > 0.0:
logits = apply_diversity_penalty(logits, diversity_penalty)
logits = logits / max(temperature, 1e-5)
if top_k > 0:
logits = top_k_filter(logits, top_k)
probs = F.softmax(logits, dim=-1)
# Stochastic at every step except the last (argmax at t=0)
if is_last:
x0_est = torch.argmax(probs, dim=-1)
else:
x0_est = batch_multinomial(probs)
hint = x0_est
return x0_est # (B, L)
# ------------------------------------------------------------------ #
# p_sample_step β€” used by generate_beam (not for production) #
# ------------------------------------------------------------------ #
def p_sample_step(
self,
model,
x_t,
t,
condition,
beam_width = 3,
temperature = 1.0,
repetition_penalty = 1.2,
diversity_penalty = 0.3,
x0_hint = None,
):
with torch.no_grad():
if x_t.dim() == 1: x_t = x_t.unsqueeze(0)
if condition.dim() == 1: condition = condition.unsqueeze(0)
if t.dim() == 0: t = t.unsqueeze(0)
if t.shape[0] != x_t.shape[0]:
t = t.expand(x_t.shape[0])
# No inference_mode β€” matches training convention
logits, _ = model(condition, x_t, t, x0_hint=x0_hint)
logits = logits / max(temperature, 1e-5)
if repetition_penalty != 1.0:
logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
if diversity_penalty > 0.0:
logits = apply_diversity_penalty(logits, diversity_penalty)
probs = F.softmax(logits, dim=-1)
B, L, V = probs.shape
topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
candidates = []
for k in range(beam_width):
next_tokens = topk_ids[:, :, k]
score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
candidates.append((next_tokens, score))
return candidates
# ------------------------------------------------------------------ #
# generate_beam β€” kept for reference; NOT the correct D3PM method #
# ------------------------------------------------------------------ #
def generate_beam(
self,
model,
condition,
beam_width = 3,
num_steps = None,
temperature = None,
repetition_penalty = None,
diversity_penalty = None,
):
"""
WARNING: do NOT call this from app1.py for D3PM generation.
generate_beam() forces every position to the same top-k token
β†’ all-Ṛ / all-rud repetition. Use generate() instead.
Kept only for experimental reference.
"""
temperature = temperature if temperature is not None else self.temperature
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
if num_steps is None:
num_steps = self.scheduler.num_timesteps
device = condition.device
if condition.dim() == 1: condition = condition.unsqueeze(0)
B, L = condition.shape
x_init = torch.full((B, L), fill_value=model.mask_token_id,
dtype=torch.long, device=device)
beams = [(x_init, 0.0)]
best_hint = None
for step in reversed(range(num_steps)):
t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
new_beams = []
for x_t, score in beams:
candidates = self.p_sample_step(
model, x_t, t_tensor, condition,
beam_width = beam_width,
temperature = temperature,
repetition_penalty = repetition_penalty,
diversity_penalty = diversity_penalty,
x0_hint = best_hint,
)
for tokens, new_score in candidates:
new_beams.append((tokens, score + new_score.item()))
new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
beams = new_beams[:beam_width]
best_hint = beams[0][0]
return beams[0][0] # (B, L)
# ── Penalty helpers ────────────────────────────────────────────────────────
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
"""Down-weight tokens already present in the sequence."""
for b in range(logits.shape[0]):
for token_id in set(prev_tokens[b].tolist()):
if token_id > 4:
logits[b, :, token_id] = logits[b, :, token_id] / penalty
return logits
def apply_diversity_penalty(logits, penalty=0.3):
"""
Correct diversity penalty: suppress globally dominant tokens.
logits -= penalty * mean(logits, dim=1) [sequence dimension]
"""
global_mean = logits.mean(dim=1, keepdim=True) # [B, 1, V]
return logits - penalty * global_mean
def top_k_filter(logits, k):
B, L, V = logits.shape
if k >= V: return logits
topk_vals, _ = torch.topk(logits, k, dim=-1)
return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
def batch_multinomial(probs):
B, L, V = probs.shape
flat = probs.view(B * L, V) + 1e-9
flat = flat / flat.sum(dim=-1, keepdim=True)
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)