"""Audio-native DiffusionGemma.

Grafts an encoder-free audio pathway onto the real
`DiffusionGemmaForBlockDiffusion` and trains it with the model's *own* uniform
discrete-diffusion objective.

Verified mechanism (from transformers `models/diffusion_gemma`, 2026-06):

* Encoder-decoder model. The ENCODER (causal) turns the prompt `input_ids`
  into a read-only KV cache; the DECODER refines a `decoder_input_ids` canvas
  with bidirectional self-attention plus cross-attention to that cache.
  Encoder/decoder transformer weights are tied.
* Multimodal inputs are injected by scattering projected features into the
  placeholder-token positions of the encoder's `inputs_embeds` (the vision
  path uses `image_token_id`; we add the audio analog at
  `AudioDiffusionConfig.audio_token_id`).
* Generation is UNIFORM discrete diffusion: the canvas starts as
  uniform-random tokens; each step accepts low-entropy predictions and
  *renoises the rest to fresh uniform-random tokens*. There is NO absorbing
  (mask) state.

Training therefore denoises uniform corruption: take the clean transcript
canvas x0, replace a random fraction gamma of positions with uniform-random
tokens to get x_t, and train the model (cross-entropy) to predict x0 at the
corrupted positions, conditioned on the audio held in the encoder cache.
"""
from __future__ import annotations

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from audio import AudioProjector


@dataclass
class AudioDiffusionConfig:
    """Hyper-parameters of the audio pathway and the token ids it relies on."""

    model_dir: str  # DiffusionGemma checkpoint dir / hub id
    whisper_id: str = "openai/whisper-small"  # frozen acoustic encoder
    whisper_dim: int = 768  # Whisper encoder width (projector input dim)
    d_model: int = 2816  # projector output dim; must match the token-embedding width
    vocab_size: int = 262144  # range used for uniform-random corruption tokens
    boa_token_id: int = 256000  # begin-of-audio marker (not read in this module)
    audio_token_id: int = 258881  # placeholder slots overwritten with audio embeds
    eoa_token_id: int = 258883  # end-of-audio marker (not read in this module)
    final_logit_softcapping: float = 30.0  # logits -> tanh(l / c) * c
    subsample_factor: int = 8  # forwarded to AudioProjector
    proj_hidden: int = 1280  # AudioProjector hidden width


class AudioDiffusionGemma(nn.Module):
    """DiffusionGemma with a frozen-Whisper + trainable-projector audio input.

    Projected audio frames replace the ``cfg.audio_token_id`` placeholders in
    the encoder prompt; the resulting encoder KV cache conditions the
    bidirectional diffusion decoder.
    """

    # Token id 0 serves both as the transcript PAD id (AR loss masking) and as
    # the CTC blank; `forward` and `ctc_greedy` must agree on these values.
    _PAD_ID = 0
    _CTC_BLANK_ID = 0

    def __init__(self, base, cfg: AudioDiffusionConfig, whisper=None):
        super().__init__()
        self.base = base  # DiffusionGemmaForBlockDiffusion
        self.whisper = whisper  # frozen Whisper encoder (feature extractor)
        self.cfg = cfg
        self.projector = AudioProjector(
            d_model=cfg.d_model,
            in_dim=cfg.whisper_dim,
            hidden=cfg.proj_hidden,
            subsample_factor=cfg.subsample_factor,
        )
        # Sub-module handles (avoid the vision tower entirely).
        self.text_encoder = base.model.encoder.language_model  # DiffusionGemmaEncoderTextModel
        self.decoder = base.model.decoder  # DiffusionGemmaDecoderModel
        self.embed = base.get_input_embeddings()  # scaled word embedding (tied)
        self.lm_head = base.lm_head

    # ---- construction ----
    @classmethod
    def from_pretrained(cls, cfg: AudioDiffusionConfig, dtype=torch.bfloat16, device="cuda"):
        """Load the DiffusionGemma backbone plus a frozen Whisper encoder.

        The backbone class is resolved from the checkpoint config's
        ``architectures[0]``. Backbone and Whisper are loaded in ``dtype``; the
        projector (freshly built in ``__init__``) is moved to fp32.
        """
        import transformers
        from transformers import AutoConfig, WhisperModel

        hfcfg = AutoConfig.from_pretrained(cfg.model_dir)
        model_class = getattr(transformers, hfcfg.architectures[0])
        base = model_class.from_pretrained(cfg.model_dir, dtype=dtype, device_map=device)

        # Frozen Whisper encoder (acoustic feature extractor; NOT a decoder).
        whisper = WhisperModel.from_pretrained(cfg.whisper_id, dtype=dtype).encoder
        whisper = whisper.to(device).eval()
        for p in whisper.parameters():
            p.requires_grad_(False)

        model = cls(base, cfg, whisper=whisper)
        # Keep the trainable projector in fp32 for stable AdamW; backbone stays bf16.
        model.projector = model.projector.to(device=device, dtype=torch.float32)
        return model

    def freeze_backbone(self):
        """Stage-1: only the audio projector trains (Whisper is frozen at load)."""
        for p in self.base.parameters():
            p.requires_grad_(False)
        for p in self.projector.parameters():
            p.requires_grad_(True)

    def apply_lora(self, r: int = 16, alpha: int = 32, dropout: float = 0.05):
        """Add LoRA to encoder + decoder self-attention projections.

        Targets q/k/v/o of every ``decoder`` and ``encoder.language_model``
        layer, so the encoder can learn to fold audio into the cache and the
        decoder can learn to attend to it. ``get_peft_model`` freezes all other
        backbone weights; the projector is re-enabled for training.

        Returns:
            ``self`` (for chaining).
        """
        from peft import LoraConfig, get_peft_model

        target = r".*(decoder|encoder\.language_model)\.layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)$"
        lcfg = LoraConfig(r=r, lora_alpha=alpha, lora_dropout=dropout, target_modules=target, bias="none")
        # PEFT injects adapters in place, so the cached sub-module handles
        # (text_encoder / decoder / lm_head) still point at the adapted modules.
        self.base = get_peft_model(self.base, lcfg)
        for p in self.projector.parameters():
            p.requires_grad_(True)
        return self

    def lora_state_dict(self):
        """Return only the LoRA adapter weights (projector weights NOT included)."""
        from peft import get_peft_model_state_dict

        return get_peft_model_state_dict(self.base)

    def trainable_parameters(self):
        """List parameters with ``requires_grad`` (deduplicated by ``parameters()``)."""
        return [p for p in self.parameters() if p.requires_grad]

    # ---- audio injection into the encoder ----
    def _project(self, mel):
        """Frozen Whisper -> trainable projector -> [B, T_a, H] audio embeds."""
        wdtype = next(self.whisper.parameters()).dtype
        with torch.no_grad():  # encoder frozen -> no grad/activations
            wfeat = self.whisper(mel.to(wdtype)).last_hidden_state  # [B, 1500, 768]
        return self.projector(wfeat.float())  # [B, T_a, H] (fp32 projector)

    def _scatter_audio(self, input_ids, audio):
        """Embed ``input_ids``, then overwrite ``audio_token_id`` slots with audio.

        Audio rows are flattened across the batch and consumed in row-major
        order, so (presumably) each sample must carry exactly as many
        placeholders as it has audio frames -- only the total is checked here.

        Raises:
            ValueError: if the placeholder count differs from the number of
                audio embedding rows.
        """
        hidden = audio.shape[-1]
        flat = audio.reshape(-1, hidden)
        embeds = self.embed(input_ids)
        audio_pos = input_ids == self.cfg.audio_token_id
        n_slots = int(audio_pos.sum())
        # Explicit check instead of `assert` so it survives `python -O`;
        # masked_scatter would otherwise silently ignore surplus audio rows.
        if n_slots != flat.shape[0]:
            raise ValueError(
                f"audio placeholders ({n_slots}) != audio embeds ({flat.shape[0]})"
            )
        return embeds.masked_scatter(audio_pos.unsqueeze(-1), flat.to(embeds.dtype))

    def _encode(self, prompt_ids, prompt_mask, audio):
        """Scatter audio into the prompt and run the encoder -> KV cache."""
        embeds = self._scatter_audio(prompt_ids, audio)
        enc = self.text_encoder(inputs_embeds=embeds, attention_mask=prompt_mask)
        return enc.past_key_values

    def _softcap(self, logits):
        """Apply ``tanh(logits / c) * c`` in fp32 (c = final_logit_softcapping)."""
        sc = self.cfg.final_logit_softcapping
        return torch.tanh(logits.float() / sc) * sc

    def _decode_logits(self, cache, canvas_ids, decoder_attention_mask, self_conditioning_logits=None):
        """Run the diffusion decoder on ``canvas_ids``; return soft-capped fp32 logits."""
        dec = self.decoder(
            decoder_input_ids=canvas_ids,
            past_key_values=cache,
            decoder_attention_mask=decoder_attention_mask,
            self_conditioning_logits=self_conditioning_logits,
        )
        return self._softcap(self.lm_head(dec.last_hidden_state))

    # ---- training step ----
    def forward(self, batch, gamma_min: float = 0.0, high_gamma_frac: float = 0.0,
                ar_weight: float = 0.0, ctc_weight: float = 0.0, weight_by_gamma: bool = False):
        """Joint training step: diffusion loss + optional AR and CTC aux losses.

        Author's rationale: the diffusion objective alone is a weak teacher for
        audio conditioning (the frozen LM prior out-competes the diluted,
        whole-canvas gradient). The AR aux loss teacher-forces the transcript
        through DiffusionGemma's OWN causal encoder (strong per-token
        audio->text gradient); because the encoder shares weights with the
        diffusion decoder and uses the same projector, that grounding is meant
        to transfer to parallel denoising. The CTC aux loss supervises the
        projector directly through ``lm_head``.

        Args:
            batch: dict with ``prompt_ids``, ``prompt_mask``, ``mel``,
                ``canvas`` ([B, L] clean target) and ``canvas_loss_mask``; when
                ``ctc_weight > 0`` also ``ctc_targets``,
                ``audio_real_lengths`` and ``ctc_target_lengths``.
            gamma_min: lower bound of the per-sample corruption rate.
            high_gamma_frac: probability of forcing a sample's rate to 1.
            ar_weight / ctc_weight: aux-loss weights (0 disables the branch).
            weight_by_gamma: divide per-token CE by the sample's corruption rate.

        Returns:
            dict with ``loss`` (differentiable) and detached ``token_acc``,
            ``diff_loss``, ``ar_loss``, ``ctc_loss``.
        """
        prompt_ids = batch["prompt_ids"]
        prompt_mask = batch["prompt_mask"]
        mel = batch["mel"]
        x0 = batch["canvas"]  # [B, L] clean target (tokens, EOS, PAD)
        loss_mask = batch["canvas_loss_mask"]
        B, L = x0.shape
        P = prompt_ids.shape[1]
        device = x0.device
        pad = self._PAD_ID

        audio = self._project(mel)  # [B, T_a, H] -- shared by all branches

        # --- diffusion branch (uniform corruption q(x_t | x0)) ---
        gamma = torch.empty(B, 1, device=device).uniform_(gamma_min, 1.0)
        if high_gamma_frac > 0:
            force_full = torch.rand(B, 1, device=device) < high_gamma_frac
            gamma = torch.where(force_full, torch.ones_like(gamma), gamma)
        corrupt = torch.rand(B, L, device=device) < gamma
        rand_tok = torch.randint(0, self.cfg.vocab_size, (B, L), device=device)
        x_t = torch.where(corrupt, rand_tok, x0)

        cache = self._encode(prompt_ids, prompt_mask, audio)
        dec_mask = torch.cat(
            [prompt_mask, torch.ones(B, L, device=device, dtype=prompt_mask.dtype)], dim=1)
        logits = self._decode_logits(cache, x_t, dec_mask)  # [B, L, V]

        train_pos = corrupt & loss_mask.bool()
        if train_pos.sum() == 0:
            # Nothing was corrupted at a loss position: fall back to all loss positions.
            train_pos = loss_mask.bool()
        ce = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), x0.reshape(-1),
                             reduction="none").reshape(B, L)
        if weight_by_gamma:
            ce = ce / gamma.clamp_min(1e-3)
        diff_loss = (ce * train_pos.float()).sum() / train_pos.float().sum().clamp_min(1.0)

        # --- AR auxiliary branch (teacher-forced transcript through the AR encoder) ---
        ar_loss = torch.zeros((), device=device)
        if ar_weight > 0:
            ar_ids = torch.cat([prompt_ids, x0], dim=1)  # [B, P+L]
            ar_real = (x0 != pad)  # transcript+EOS
            ar_attn = torch.cat([prompt_mask, ar_real.to(prompt_mask.dtype)], dim=1)
            ar_embeds = self._scatter_audio(ar_ids, audio)
            ar_hidden = self.text_encoder(inputs_embeds=ar_embeds, attention_mask=ar_attn).last_hidden_state
            # Causal shift: hidden state at position P-1+t predicts x0[:, t].
            ar_logits = self._softcap(self.lm_head(ar_hidden[:, P - 1:P + L - 1, :]))
            ar_ce = F.cross_entropy(ar_logits.reshape(-1, ar_logits.shape[-1]), x0.reshape(-1),
                                    reduction="none").reshape(B, L)
            ar_loss = (ar_ce * ar_real.float()).sum() / ar_real.float().sum().clamp_min(1.0)

        # --- CTC auxiliary branch (direct projector supervision; breaks the
        # chicken-and-egg by making audio embeds transcript-predictive in the
        # LLM token space via the frozen lm_head, independent of attention) ---
        ctc_loss = torch.zeros((), device=device)
        if ctc_weight > 0:
            ctc_logits = self.lm_head(audio.to(self.lm_head.weight.dtype))  # [B, T_a, V]
            log_probs = ctc_logits.float().log_softmax(-1).transpose(0, 1)  # [T_a, B, V]
            ctc_loss = F.ctc_loss(
                log_probs, batch["ctc_targets"], batch["audio_real_lengths"],
                batch["ctc_target_lengths"], blank=self._CTC_BLANK_ID, zero_infinity=True)

        loss = diff_loss + ar_weight * ar_loss + ctc_weight * ctc_loss
        with torch.no_grad():
            acc = ((logits.argmax(-1) == x0) & train_pos).float().sum() / train_pos.float().sum().clamp_min(1.0)
        return {"loss": loss, "token_acc": acc.detach(), "diff_loss": diff_loss.detach(),
                "ar_loss": ar_loss.detach(), "ctc_loss": ctc_loss.detach()}

    # ---- diagnostic: CTC greedy decode straight from the projector (is audio grounded?) ----
    @torch.no_grad()
    def ctc_greedy(self, prompt_ids, prompt_mask, mel, audio_real_lengths):
        """Greedy CTC decode of ``lm_head(projector(audio))``.

        ``prompt_ids`` / ``prompt_mask`` are accepted for signature parity and
        are not used. Returns one list of token ids per sample: argmax path
        truncated to ``audio_real_lengths[b]``, repeats collapsed, blanks dropped.
        """
        audio = self._project(mel)  # [B, T_a, H]
        ids = self.lm_head(audio.to(self.lm_head.weight.dtype)).argmax(-1)  # [B, T_a]
        blank = self._CTC_BLANK_ID
        out = []
        for b in range(ids.shape[0]):
            seq = ids[b, : int(audio_real_lengths[b])].tolist()
            collapsed, prev = [], None
            for t in seq:
                # Collapse repeats and drop blanks; a blank between two equal
                # tokens still separates them because `prev` tracks blanks too.
                if t != prev and t != blank:
                    collapsed.append(t)
                prev = t
            out.append(collapsed)
        return out

    # ---- inference: audio-conditioned parallel denoising ----
    def _mask_mapping(self, prompt_mask, cache, canvas_len, device):
        """Build the decoder attention mask for a prompt + ``canvas_len`` canvas.

        The dummy ``inputs_embeds`` of width 1 presumably only supplies
        batch/length/device to the mask builder -- TODO confirm.
        """
        dec_mask = torch.cat(
            [prompt_mask, torch.ones(prompt_mask.shape[0], canvas_len, device=device, dtype=prompt_mask.dtype)],
            dim=1)
        return self.decoder.create_diffusion_decoder_attention_mask(
            config=self.base.config.text_config,
            inputs_embeds=torch.empty(prompt_mask.shape[0], canvas_len, 1, device=device),
            past_key_values=cache,
            decoder_attention_mask=dec_mask)

    @torch.no_grad()
    def generate(self, prompt_ids, prompt_mask, mel, *, canvas_len: int = 256, max_steps: int = 48,
                 entropy_bound: float = 0.1, t_min: float = 0.4, t_max: float = 0.8,
                 confidence_threshold: float = 0.005, stability_threshold: int = 1,
                 guidance_weight: float = 1.0):
        """DiffusionGemma's uniform-diffusion sampler, conditioned on audio.

        Supports optional classifier-free guidance (CFG). With
        ``guidance_weight`` w != 1, each step runs a conditional pass (decoder
        attends to the audio KV) and an unconditional pass (same cache, but the
        decoder mask hides the audio placeholder positions, leaving a
        prompt-only prior) and combines them as::

            guided = uncond + w * (cond - uncond)

        This is intended to counter the "conditioning ignored" failure without
        retraining. Note that both passes receive the previous step's
        *conditional* logits as self-conditioning input.

        Returns:
            (argmax_canvas [B, canvas_len], steps_used)
        """
        from transformers.models.diffusion_gemma.generation_diffusion_gemma import (
            EntropyBoundSampler,
            EntropyBoundSamplerConfig,
            LinearTemperatureScheduleLogitsProcessor,
            StableAndConfidentStoppingCriteria,
        )

        device = prompt_ids.device
        B = prompt_ids.shape[0]
        cache = self._encode(prompt_ids, prompt_mask, self._project(mel))
        cfg_on = guidance_weight != 1.0

        sampler = EntropyBoundSampler(
            EntropyBoundSamplerConfig(entropy_bound=entropy_bound),
            canvas_len, self.cfg.vocab_size, max_steps)
        temp = LinearTemperatureScheduleLogitsProcessor(t_min, t_max, max_steps)
        stopping = StableAndConfidentStoppingCriteria(stability_threshold, confidence_threshold)

        mm_cond = self._mask_mapping(prompt_mask, cache, canvas_len, device)
        if cfg_on:
            # Unconditional branch: hide the audio KV from the decoder.
            pmask_uncond = prompt_mask.clone()
            pmask_uncond[prompt_ids == self.cfg.audio_token_id] = 0
            mm_uncond = self._mask_mapping(pmask_uncond, cache, canvas_len, device)

        current = sampler.initialize_canvas(B, device)
        argmax_canvas = current.clone()  # returned as-is if max_steps == 0
        self_cond = None
        steps_used = 0
        for step in range(max_steps):
            cond = self._decode_logits(cache, current, mm_cond, self_conditioning_logits=self_cond)
            if cfg_on:
                uncond = self._decode_logits(cache, current, mm_uncond, self_conditioning_logits=self_cond)
                logits = uncond + guidance_weight * (cond - uncond)
            else:
                logits = cond
            processed = temp(current, logits, cur_step=step)
            probs = torch.softmax(processed, dim=-1, dtype=torch.float32)
            denoiser = torch.multinomial(probs.view(-1, probs.shape[-1]), 1).view(B, canvas_len)
            argmax_canvas = processed.argmax(-1)
            accepted = sampler.accept_canvas(current, denoiser, processed, step)
            current = sampler.renoise_canvas(accepted, step)
            self_cond = cond
            steps_used = step + 1
            if bool(stopping(argmax_canvas, processed).all()):
                break
        return argmax_canvas, steps_used