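# Hugging Face file-view header (page residue, kept as comments):
#   uploader: HV-Khurdula
#   commit message: Publish diffusion-gemma-asr-small (fleurs2/projector_ep3.pt)
#   commit: 583a90c (verified)
#   file size: 16.1 kB
#   (page controls: Raw / History / Blame / Contribute / Delete)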
"""Audio-native DiffusionGemma: graft an encoder-free audio pathway onto the
real `DiffusionGemmaForBlockDiffusion`, and train it with the model's *own*
uniform discrete-diffusion objective.
Verified mechanism (from transformers `models/diffusion_gemma`, 2026-06):
* Encoder–decoder model. The ENCODER (causal) turns the prompt `input_ids` into
a read-only KV cache; the DECODER refines a `decoder_input_ids` canvas with
bidirectional self-attention + cross-attention to that cache. Encoder/decoder
transformer weights are tied.
* Multimodal inputs are injected by scattering projected features into the
placeholder-token positions of the encoder's `inputs_embeds` (the vision path
uses `image_token_id`; we add the audio analog at `AUDIO_TOKEN_ID`).
* Generation is UNIFORM discrete diffusion: canvas starts as uniform-random
tokens; each step accepts low-entropy predictions and *renoises the rest to
fresh uniform-random tokens*. There is NO absorbing <mask> state.
So training = denoising cross-entropy (x0-prediction) under uniform corruption:
take the clean transcript canvas x0, replace a fraction γ of its positions with
uniform-random tokens to get x_t, and train the model to predict x0 at the
corrupted positions, conditioned on the audio (held in the encoder cache).
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from audio import AudioProjector
@dataclass
class AudioDiffusionConfig:
    """Settings for the audio-conditioned DiffusionGemma wrapper.

    Only ``model_dir`` is required; every other field has a default.
    Field order is part of the positional constructor and must not change.
    """

    # Backbone checkpoint location handed to ``from_pretrained``.
    model_dir: str

    # Acoustic feature extractor and the width of its output features.
    whisper_id: str = "openai/whisper-small"
    whisper_dim: int = 768

    # Backbone hidden width (projector output) and vocabulary size
    # (upper bound for the uniform-random corruption tokens).
    d_model: int = 2816
    vocab_size: int = 262144

    # Special token ids. Only ``audio_token_id`` is read in this file;
    # the begin/end-of-audio ids are presumably used by the data pipeline.
    boa_token_id: int = 256000
    audio_token_id: int = 258881
    eoa_token_id: int = 258883

    # tanh soft-cap applied to output logits.
    final_logit_softcapping: float = 30.0

    # AudioProjector construction arguments.
    subsample_factor: int = 8
    proj_hidden: int = 1280
class AudioDiffusionGemma(nn.Module):
def __init__(self, base, cfg: AudioDiffusionConfig, whisper=None):
    """Wrap a loaded backbone and attach a freshly built audio projector.

    Args:
        base: the loaded backbone (a ``DiffusionGemmaForBlockDiffusion``
            per the original author's note).
        cfg: projector sizes, token ids and logit soft-cap.
        whisper: optional frozen Whisper encoder used as feature extractor.
    """
    super().__init__()

    # Assignment order is kept as-is: nn.Module registers submodules in the
    # order they are assigned, which fixes the parameter iteration order.
    self.base = base
    self.whisper = whisper
    self.cfg = cfg

    projector_kwargs = dict(
        d_model=cfg.d_model,
        in_dim=cfg.whisper_dim,
        hidden=cfg.proj_hidden,
        subsample_factor=cfg.subsample_factor,
    )
    self.projector = AudioProjector(**projector_kwargs)

    # Direct handles on the text encoder / decoder so the vision tower is
    # never touched.
    backbone = base.model
    self.text_encoder = backbone.encoder.language_model
    self.decoder = backbone.decoder
    self.embed = base.get_input_embeddings()  # scaled word embedding (tied)
    self.lm_head = base.lm_head
# ---- construction ----
@classmethod
def from_pretrained(cls, cfg: AudioDiffusionConfig, dtype=torch.bfloat16, device="cuda"):
    """Load the backbone and a frozen Whisper encoder, then build the wrapper.

    Args:
        cfg: supplies ``model_dir`` (backbone) and ``whisper_id`` (encoder).
        dtype: dtype used to load both the backbone and Whisper.
        device: target device for backbone, Whisper and projector.

    Returns:
        A new instance whose projector has been moved to ``device`` in fp32.
    """
    import transformers
    from transformers import AutoConfig, WhisperModel

    # The backbone class is looked up by the architecture name stored in the
    # checkpoint's config.
    arch_name = AutoConfig.from_pretrained(cfg.model_dir).architectures[0]
    backbone_cls = getattr(transformers, arch_name)
    backbone = backbone_cls.from_pretrained(cfg.model_dir, dtype=dtype, device_map=device)

    # Only Whisper's encoder is kept; it is a frozen acoustic feature
    # extractor, not a decoder.
    acoustic = WhisperModel.from_pretrained(cfg.whisper_id, dtype=dtype).encoder
    acoustic = acoustic.to(device).eval()
    for weight in acoustic.parameters():
        weight.requires_grad_(False)

    wrapper = cls(backbone, cfg, whisper=acoustic)
    # The trainable projector stays in fp32 for stable AdamW updates while
    # the backbone keeps the requested dtype.
    wrapper.projector = wrapper.projector.to(device=device, dtype=torch.float32)
    return wrapper
def freeze_backbone(self):
    """Stage-1 setup: freeze every backbone weight, train only the projector."""
    # Backbone first, projector second.
    for module, trainable in ((self.base, False), (self.projector, True)):
        for weight in module.parameters():
            weight.requires_grad_(trainable)
def apply_lora(self, r: int = 16, alpha: int = 32, dropout: float = 0.05):
    """Attach LoRA adapters to the encoder and decoder attention projections.

    Adapting both sides lets the encoder learn to fold audio into the cache
    and the decoder learn to attend to it. The projector is re-enabled for
    training afterwards; experts and the vision tower stay frozen.

    Args:
        r: LoRA rank.
        alpha: LoRA scaling (``lora_alpha``).
        dropout: LoRA dropout probability.

    Returns:
        ``self``, so the call can be chained.
    """
    from peft import LoraConfig, get_peft_model

    # Matches q/k/v/o projections of self-attention in decoder layers and in
    # the encoder's language-model layers.
    attn_proj_pattern = r".*(decoder|encoder\.language_model)\.layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)$"
    lora_config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=attn_proj_pattern,
        bias="none",
    )

    # get_peft_model freezes the base weights and enables the adapters; the
    # submodule handles captured in __init__ keep pointing at the same objects.
    self.base = get_peft_model(self.base, lora_config)

    for weight in self.projector.parameters():
        weight.requires_grad_(True)
    return self
def lora_state_dict(self):
    """Return the PEFT state dict of the (LoRA-wrapped) backbone."""
    from peft import get_peft_model_state_dict

    adapter_weights = get_peft_model_state_dict(self.base)
    return adapter_weights
def trainable_parameters(self):
    """Return a list of every parameter that currently requires gradients."""
    return list(filter(lambda weight: weight.requires_grad, self.parameters()))
# ---- audio injection into the encoder ----
def _project(self, mel):
"""Frozen Whisper -> trainable projector -> [B, T_a, H] audio embeds."""
wdtype = next(self.whisper.parameters()).dtype
with torch.no_grad(): # encoder frozen -> no grad/activations
wfeat = self.whisper(mel.to(wdtype)).last_hidden_state # [B, 1500, 768]
return self.projector(wfeat.float()) # [B, T_a, H] (fp32 projector)
def _scatter_audio(self, input_ids, audio):
"""Embed token ids, then overwrite AUDIO_TOKEN_ID slots with audio embeds."""
H = audio.shape[-1]
flat = audio.reshape(-1, H)
embeds = self.embed(input_ids)
audio_pos = input_ids == self.cfg.audio_token_id
assert int(audio_pos.sum()) == flat.shape[0], (
f"audio placeholders ({int(audio_pos.sum())}) != audio embeds ({flat.shape[0]})"
)
return embeds.masked_scatter(audio_pos.unsqueeze(-1), flat.to(embeds.dtype))
def _encode(self, prompt_ids, prompt_mask, audio):
"""Scatter audio into the prompt and run the encoder -> KV cache."""
embeds = self._scatter_audio(prompt_ids, audio)
enc = self.text_encoder(inputs_embeds=embeds, attention_mask=prompt_mask)
return enc.past_key_values
def _decode_logits(self, cache, canvas_ids, decoder_attention_mask, self_conditioning_logits=None):
dec = self.decoder(
decoder_input_ids=canvas_ids,
past_key_values=cache,
decoder_attention_mask=decoder_attention_mask,
self_conditioning_logits=self_conditioning_logits,
)
logits = self.lm_head(dec.last_hidden_state).float()
sc = self.cfg.final_logit_softcapping
logits = torch.tanh(logits / sc) * sc
return logits
def _softcap(self, logits):
sc = self.cfg.final_logit_softcapping
return torch.tanh(logits.float() / sc) * sc
# ---- training step ----
def forward(self, batch, gamma_min: float = 0.0, high_gamma_frac: float = 0.0,
            ar_weight: float = 0.0, ctc_weight: float = 0.0, weight_by_gamma: bool = False):
    """One joint training step: diffusion loss plus optional AR and CTC aux losses.

    The diffusion objective alone is a weak teacher for audio conditioning
    (the frozen LM prior out-competes the diluted whole-canvas gradient).
    The AR auxiliary loss teacher-forces the transcript through the text
    encoder for a strong per-token audio-to-text gradient; since encoder and
    decoder share weights and the projector, that grounding is meant to
    transfer to parallel denoising. The CTC auxiliary loss supervises the
    projector directly through ``lm_head``.

    Args:
        batch: dict with ``prompt_ids``, ``prompt_mask``, ``mel``, ``canvas``
            (2-D clean target ids) and ``canvas_loss_mask``; additionally
            ``ctc_targets``, ``audio_real_lengths`` and ``ctc_target_lengths``
            when ``ctc_weight > 0``.
        gamma_min: lower bound of the per-sample corruption rate.
        high_gamma_frac: probability of forcing a sample's rate to 1.0.
        ar_weight: weight of the AR loss; the branch is skipped when <= 0.
        ctc_weight: weight of the CTC loss; the branch is skipped when <= 0.
        weight_by_gamma: divide per-token diffusion CE by the corruption rate
            (clamped from below at 1e-3).

    Returns:
        dict with ``loss`` (carries gradients) and detached ``token_acc``,
        ``diff_loss``, ``ar_loss``, ``ctc_loss``.
    """

    def masked_mean(values, mask):
        # Mean of `values` over the True entries of `mask`; the denominator
        # is clamped to 1 so an empty mask yields 0 instead of NaN.
        weights = mask.float()
        return (values * weights).sum() / weights.sum().clamp_min(1.0)

    prompt_ids = batch["prompt_ids"]
    prompt_mask = batch["prompt_mask"]
    mel = batch["mel"]
    clean = batch["canvas"]  # clean target canvas (tokens, EOS, PAD)
    supervised = batch["canvas_loss_mask"].bool()

    batch_size, canvas_len = clean.shape
    prompt_len = prompt_ids.shape[1]
    device = clean.device
    pad_id = 0

    # Audio embeddings are shared by every branch below.
    audio = self._project(mel)

    # ---- diffusion branch: uniform corruption q(x_t | x0) ----
    # NOTE: the order of the random draws below is part of the behaviour.
    noise_rate = torch.empty(batch_size, 1, device=device).uniform_(gamma_min, 1.0)
    if high_gamma_frac > 0:
        fully_noised = torch.rand(batch_size, 1, device=device) < high_gamma_frac
        noise_rate = torch.where(fully_noised, torch.ones_like(noise_rate), noise_rate)
    corrupted = torch.rand(batch_size, canvas_len, device=device) < noise_rate
    noise_tokens = torch.randint(
        0, self.cfg.vocab_size, (batch_size, canvas_len), device=device)
    noisy = torch.where(corrupted, noise_tokens, clean)

    cache = self._encode(prompt_ids, prompt_mask, audio)
    canvas_visible = torch.ones(
        batch_size, canvas_len, device=device, dtype=prompt_mask.dtype)
    decoder_mask = torch.cat([prompt_mask, canvas_visible], dim=1)
    logits = self._decode_logits(cache, noisy, decoder_mask)

    # Score corrupted positions inside the loss mask; if the whole batch has
    # none, fall back to every loss-mask position.
    scored = corrupted & supervised
    if scored.sum() == 0:
        scored = supervised

    vocab = logits.shape[-1]
    token_ce = F.cross_entropy(
        logits.reshape(-1, vocab), clean.reshape(-1), reduction="none",
    ).reshape(batch_size, canvas_len)
    if weight_by_gamma:
        token_ce = token_ce / noise_rate.clamp_min(1e-3)
    diff_loss = masked_mean(token_ce, scored)

    # ---- AR auxiliary branch: teacher-forced transcript through the encoder ----
    ar_loss = torch.zeros((), device=device)
    if ar_weight > 0:
        ar_ids = torch.cat([prompt_ids, clean], dim=1)
        is_text = clean != pad_id  # transcript + EOS
        ar_attn = torch.cat([prompt_mask, is_text.to(prompt_mask.dtype)], dim=1)
        ar_embeds = self._scatter_audio(ar_ids, audio)
        ar_out = self.text_encoder(inputs_embeds=ar_embeds, attention_mask=ar_attn)
        # Hidden state at position i predicts token i + 1, so the window
        # starting at prompt_len - 1 lines up with the canvas tokens.
        window = ar_out.last_hidden_state[:, prompt_len - 1:prompt_len + canvas_len - 1, :]
        ar_logits = self._softcap(self.lm_head(window))
        ar_ce = F.cross_entropy(
            ar_logits.reshape(-1, ar_logits.shape[-1]), clean.reshape(-1), reduction="none",
        ).reshape(batch_size, canvas_len)
        ar_loss = masked_mean(ar_ce, is_text)

    # ---- CTC auxiliary branch: direct projector supervision via lm_head ----
    ctc_loss = torch.zeros((), device=device)
    if ctc_weight > 0:
        frame_logits = self.lm_head(audio.to(self.lm_head.weight.dtype))
        # F.ctc_loss expects time-major log-probabilities.
        frame_log_probs = frame_logits.float().log_softmax(-1).transpose(0, 1)
        ctc_loss = F.ctc_loss(
            frame_log_probs,
            batch["ctc_targets"],
            batch["audio_real_lengths"],
            batch["ctc_target_lengths"],
            blank=0,
            zero_infinity=True,
        )

    loss = diff_loss + ar_weight * ar_loss + ctc_weight * ctc_loss

    with torch.no_grad():
        hits = (logits.argmax(-1) == clean).float()
        token_acc = masked_mean(hits, scored)

    return {
        "loss": loss,
        "token_acc": token_acc.detach(),
        "diff_loss": diff_loss.detach(),
        "ar_loss": ar_loss.detach(),
        "ctc_loss": ctc_loss.detach(),
    }
# ---- diagnostic: CTC greedy decode straight from the projector (is audio grounded?) ----
@torch.no_grad()
def ctc_greedy(self, prompt_ids, prompt_mask, mel, audio_real_lengths):
    """Greedy CTC decode straight from the projector (diagnostic).

    Projects the audio, maps each frame to its arg-max vocabulary id via
    ``lm_head``, then per sample collapses runs of equal ids and drops the
    blank id 0.

    Args:
        prompt_ids: unused; kept for interface compatibility.
        prompt_mask: unused; kept for interface compatibility.
        mel: audio features passed to ``_project``.
        audio_real_lengths: per-sample count of valid audio frames; frames
            beyond that count are ignored.

    Returns:
        A list with one list of token ids per sample.
    """
    audio = self._project(mel)
    frame_ids = self.lm_head(audio.to(self.lm_head.weight.dtype)).argmax(-1)
    # Single device-to-host copy instead of one sync per row.
    frame_ids = frame_ids.cpu()

    decoded = []
    for b in range(frame_ids.shape[0]):
        frames = frame_ids[b, : int(audio_real_lengths[b])]
        if frames.numel() == 0:
            decoded.append([])
            continue
        # Collapse repeats first, then remove blanks: a blank between two
        # equal ids must keep them as separate tokens.
        merged = torch.unique_consecutive(frames)
        decoded.append(merged[merged != 0].tolist())
    return decoded
# ---- inference: audio-conditioned parallel denoising ----
def _mask_mapping(self, prompt_mask, cache, canvas_len, device):
dec_mask = torch.cat(
[prompt_mask, torch.ones(prompt_mask.shape[0], canvas_len, device=device, dtype=prompt_mask.dtype)], dim=1)
return self.decoder.create_diffusion_decoder_attention_mask(
config=self.base.config.text_config,
inputs_embeds=torch.empty(prompt_mask.shape[0], canvas_len, 1, device=device),
past_key_values=cache, decoder_attention_mask=dec_mask)
@torch.no_grad()
def generate(self, prompt_ids, prompt_mask, mel, *,
             canvas_len: int = 256, max_steps: int = 48, entropy_bound: float = 0.1,
             t_min: float = 0.4, t_max: float = 0.8,
             confidence_threshold: float = 0.005, stability_threshold: int = 1,
             guidance_weight: float = 1.0):
    """Audio-conditioned parallel denoising with optional classifier-free guidance.

    Runs DiffusionGemma's uniform-diffusion sampler on a canvas of
    ``canvas_len`` tokens, conditioned on the audio via the encoder cache.

    When ``guidance_weight`` differs from 1.0, every step performs two
    decoder passes over the same cache: a conditional one, and an
    unconditional one whose mask hides the audio placeholder positions (pure
    language prior). They are combined as
    ``guided = uncond + w * (cond - uncond)``, so ``w > 1`` amplifies the
    audio's effect without any retraining.

    Args:
        prompt_ids: prompt token ids containing the audio placeholders.
        prompt_mask: attention mask for the prompt.
        mel: audio features passed to ``_project``.
        canvas_len: number of canvas positions to denoise.
        max_steps: upper bound on denoising steps.
        entropy_bound: forwarded to ``EntropyBoundSamplerConfig``.
        t_min: forwarded to the linear temperature schedule.
        t_max: forwarded to the linear temperature schedule.
        confidence_threshold: forwarded to the stopping criterion.
        stability_threshold: forwarded to the stopping criterion.
        guidance_weight: CFG weight ``w``; 1.0 disables the second pass.

    Returns:
        ``(argmax_canvas, steps_used)``: the arg-max canvas of the last
        executed step ([B, canvas_len]) and the number of steps run.
    """
    from transformers.models.diffusion_gemma.generation_diffusion_gemma import (
        EntropyBoundSampler, EntropyBoundSamplerConfig,
        LinearTemperatureScheduleLogitsProcessor, StableAndConfidentStoppingCriteria,
    )

    device = prompt_ids.device
    batch_size = prompt_ids.shape[0]

    audio = self._project(mel)
    cache = self._encode(prompt_ids, prompt_mask, audio)
    use_guidance = guidance_weight != 1.0

    sampler_config = EntropyBoundSamplerConfig(entropy_bound=entropy_bound)
    sampler = EntropyBoundSampler(sampler_config, canvas_len, self.cfg.vocab_size, max_steps)
    temperature = LinearTemperatureScheduleLogitsProcessor(t_min, t_max, max_steps)
    should_stop = StableAndConfidentStoppingCriteria(stability_threshold, confidence_threshold)

    cond_mask = self._mask_mapping(prompt_mask, cache, canvas_len, device)
    uncond_mask = None
    if use_guidance:
        # Unconditional pass: same cache, but the audio KV entries are masked
        # out for the decoder.
        audio_hidden = prompt_mask.clone()
        audio_hidden[prompt_ids == self.cfg.audio_token_id] = 0
        uncond_mask = self._mask_mapping(audio_hidden, cache, canvas_len, device)

    canvas = sampler.initialize_canvas(batch_size, device)
    argmax_canvas = canvas.clone()
    prev_cond_logits = None
    steps_used = 0

    for step in range(max_steps):
        cond_logits = self._decode_logits(
            cache, canvas, cond_mask, self_conditioning_logits=prev_cond_logits)
        guided = cond_logits
        if use_guidance:
            uncond_logits = self._decode_logits(
                cache, canvas, uncond_mask, self_conditioning_logits=prev_cond_logits)
            guided = uncond_logits + guidance_weight * (cond_logits - uncond_logits)

        processed = temperature(canvas, guided, cur_step=step)
        probs = torch.softmax(processed, dim=-1, dtype=torch.float32)
        flat_probs = probs.view(-1, probs.shape[-1])
        proposal = torch.multinomial(flat_probs, 1).view(batch_size, canvas_len)
        argmax_canvas = processed.argmax(-1)

        accepted = sampler.accept_canvas(canvas, proposal, processed, step)
        canvas = sampler.renoise_canvas(accepted, step)

        # Self-conditioning always uses the conditional (unguided) logits.
        prev_cond_logits = cond_logits
        steps_used = step + 1
        if bool(should_stop(argmax_canvas, processed).all()):
            break

    return argmax_canvas, steps_used