Spaces:
Running
Running
| """ | |
| reverse_process.py β Final Correct Version | |
| ============================================= | |
| KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference() | |
| in inference.py, which is what produced BERTScore 0.75 at validation. | |
| CRITICAL BUG IN PREVIOUS VERSION: | |
| We passed inference_mode=True to the model, but the model was NEVER | |
| called with inference_mode=True during training or validation. | |
| run_inference() (the validated path) does: | |
| model(input_ids, x0_est, t, x0_hint=hint) | |
| β inference_mode defaults to False. | |
| With inference_mode=True the model does two things differently: | |
| 1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD) | |
| 2. Skips q_sample at t=0 (training always called q_sample) | |
| The model was never trained to handle these conditions β garbage output. | |
| Fix: do NOT pass inference_mode. Let it default to False, exactly | |
| as run_inference() did. | |
| BUGS FIXED (vs original reverse_process.py) | |
| -------------------------------------------- | |
| BUG 1 generate_beam() used for D3PM β all-αΉ repetition. | |
| Use generate() (iterative refinement) from app1.py instead. | |
| BUG 2 apply_diversity_penalty used logits.var() β noise injection. | |
| Fixed to logits - penalty * logits.mean(dim=1) β global suppression. | |
| BUG 3 x0_hint (self-conditioning) never passed to model. | |
| Fixed: generate() passes x0_hint=hint every step. | |
| BUG 4 params not forwarded from generate_beam() to p_sample_step(). | |
| Fixed in generate_beam() (kept for reference, not for production use). | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
class ReverseDiffusion:
    """Samplers for the discrete-diffusion (D3PM) decoder.

    ``generate()`` is the production path: iterative refinement that
    mirrors ``run_inference()`` in inference.py exactly — in particular
    it never passes ``inference_mode`` to the model, because the model
    was only ever trained/validated with the default (False).

    ``p_sample_step()`` / ``generate_beam()`` are experimental and kept
    for reference only.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler
        # Attribute-style defaults kept for backward compatibility with
        # callers that mutate e.g. ``reverse_diffusion.temperature``.
        # generate() prefers explicit kwargs and falls back to these.
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.top_k = 50

    # ------------------------------------------------------------------ #
    # generate — correct D3PM iterative refinement                       #
    # (exact equivalent of run_inference() in inference.py)              #
    # ------------------------------------------------------------------ #
    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=None,
        top_k=None,
        repetition_penalty=None,
        diversity_penalty=None,
    ):
        """Refine an all-[MASK] target into a token sequence.

        For t = T-1 .. 0 the model predicts x0; every step samples from
        the softmax except the final one (t=0), which takes the argmax.
        The previous estimate is fed back as ``x0_hint``
        (self-conditioning) on every step.

        Returns:
            (B, L) long tensor of token ids.
        """
        # Explicit kwargs win; fall back to the instance-level defaults.
        if temperature is None:
            temperature = self.temperature
        if top_k is None:
            top_k = self.top_k
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if diversity_penalty is None:
            diversity_penalty = self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        batch, seq_len = condition.shape

        # Evenly strided schedule from T-1 down to (always including) 0.
        total_t = self.scheduler.num_timesteps
        stride = max(1, total_t // num_steps)
        schedule = list(range(total_t - 1, -1, -stride))
        if schedule[-1] != 0:
            schedule.append(0)

        tokens = torch.full(
            (batch, seq_len), model.mask_token_id,
            dtype=torch.long, device=device,
        )
        self_cond = None

        model.eval()
        with torch.no_grad():
            final_idx = len(schedule) - 1
            for idx, t_value in enumerate(schedule):
                t = torch.full((batch,), t_value, dtype=torch.long, device=device)
                # CRITICAL: inference_mode is deliberately NOT passed.
                # The default (False) is what training/validation used;
                # True would change tgt_pad_mask and q_sample behaviour,
                # which the model was never trained to handle.
                logits, _ = model(condition, tokens, t, x0_hint=self_cond)
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, tokens, repetition_penalty
                    )
                # Diversity penalty (global mean suppression).
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)
                logits = logits / max(temperature, 1e-5)
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)
                probs = F.softmax(logits, dim=-1)
                if idx == final_idx:
                    tokens = torch.argmax(probs, dim=-1)   # deterministic at t=0
                else:
                    tokens = batch_multinomial(probs)      # stochastic elsewhere
                self_cond = tokens
        return tokens  # (B, L)

    # ------------------------------------------------------------------ #
    # p_sample_step — used by generate_beam (not for production)         #
    # ------------------------------------------------------------------ #
    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3,
        x0_hint=None,
    ):
        """One reverse step; returns ``beam_width`` (tokens, score) pairs.

        The k-th candidate takes the per-position k-th most likely token,
        scored by the summed log-probability over batch and positions.
        Experimental — used only by generate_beam().
        """
        with torch.no_grad():
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)
            if condition.dim() == 1:
                condition = condition.unsqueeze(0)
            if t.dim() == 0:
                t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # inference_mode intentionally left at its default (False),
            # matching the training convention.
            logits, _ = model(condition, x_t, t, x0_hint=x0_hint)
            logits = logits / max(temperature, 1e-5)
            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)
            top_probs, top_ids = torch.topk(probs, beam_width, dim=-1)
            return [
                (top_ids[:, :, k], torch.log(top_probs[:, :, k] + 1e-9).sum())
                for k in range(beam_width)
            ]

    # ------------------------------------------------------------------ #
    # generate_beam — kept for reference; NOT the correct D3PM method    #
    # ------------------------------------------------------------------ #
    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=None,
        repetition_penalty=None,
        diversity_penalty=None,
    ):
        """EXPERIMENTAL — do NOT call from app1.py for D3PM generation.

        Forcing every position to the same top-k rank collapses output
        into repeated tokens (all-ṛ / all-rud).  Use generate() instead.
        Kept only for reference.
        """
        if temperature is None:
            temperature = self.temperature
        if repetition_penalty is None:
            repetition_penalty = self.repetition_penalty
        if diversity_penalty is None:
            diversity_penalty = self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        batch, seq_len = condition.shape

        start = torch.full(
            (batch, seq_len), fill_value=model.mask_token_id,
            dtype=torch.long, device=device,
        )
        beams = [(start, 0.0)]
        hint = None

        for step in reversed(range(num_steps)):
            t_tensor = torch.full((batch,), step, dtype=torch.long, device=device)
            expanded = []
            for seq, seq_score in beams:
                # All sampling params ARE forwarded (historical BUG 4 fix).
                for cand_tokens, cand_score in self.p_sample_step(
                    model, seq, t_tensor, condition,
                    beam_width=beam_width,
                    temperature=temperature,
                    repetition_penalty=repetition_penalty,
                    diversity_penalty=diversity_penalty,
                    x0_hint=hint,
                ):
                    expanded.append((cand_tokens, seq_score + cand_score.item()))
            expanded.sort(key=lambda pair: pair[1], reverse=True)
            beams = expanded[:beam_width]
            hint = beams[0][0]
        return beams[0][0]  # (B, L)
| # ββ Penalty helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def apply_repetition_penalty(logits, prev_tokens, penalty=1.2): | |
| """Down-weight tokens already present in the sequence.""" | |
| for b in range(logits.shape[0]): | |
| for token_id in set(prev_tokens[b].tolist()): | |
| if token_id > 4: | |
| logits[b, :, token_id] = logits[b, :, token_id] / penalty | |
| return logits | |
def apply_diversity_penalty(logits, penalty=0.3):
    """Suppress tokens that dominate across the whole sequence.

    Subtracts ``penalty`` times the per-vocabulary mean over the
    sequence axis, so a token whose logit is high at *every* position is
    pushed down globally (mean suppression, not noise injection).
    """
    sequence_mean = logits.mean(dim=1, keepdim=True)  # (B, 1, V)
    return logits - penalty * sequence_mean
def top_k_filter(logits, k):
    """Keep the k largest logits per position; set the rest to -inf.

    Ties with the k-th value survive (strict ``<`` comparison).
    No-op when k covers the whole vocabulary.
    """
    vocab_size = logits.shape[-1]
    if k >= vocab_size:
        return logits
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_value, float('-inf'))
def batch_multinomial(probs):
    """Draw one token per (batch, position) from a (B, L, V) distribution.

    Flattens to (B*L, V) so a single multinomial call samples every
    position at once; a tiny epsilon plus renormalisation guards
    against all-zero rows.
    """
    batch, seq_len, vocab = probs.shape
    stabilised = probs.view(batch * seq_len, vocab) + 1e-9
    stabilised = stabilised / stabilised.sum(dim=-1, keepdim=True)
    draws = torch.multinomial(stabilised, 1)
    return draws.squeeze(-1).view(batch, seq_len)