Instructions to use interfaze-ai/diffusion-gemma-asr-small with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use interfaze-ai/diffusion-gemma-asr-small with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="interfaze-ai/diffusion-gemma-asr-small")
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("interfaze-ai/diffusion-gemma-asr-small", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Audio-native DiffusionGemma: graft an encoder-free audio pathway onto the | |
| real `DiffusionGemmaForBlockDiffusion`, and train it with the model's *own* | |
| uniform discrete-diffusion objective. | |
| Verified mechanism (from transformers `models/diffusion_gemma`, 2026-06): | |
| * Encoder–decoder model. The ENCODER (causal) turns the prompt `input_ids` into | |
| a read-only KV cache; the DECODER refines a `decoder_input_ids` canvas with | |
| bidirectional self-attention + cross-attention to that cache. Encoder/decoder | |
| transformer weights are tied. | |
| * Multimodal inputs are injected by scattering projected features into the | |
| placeholder-token positions of the encoder's `inputs_embeds` (the vision path | |
| uses `image_token_id`; we add the audio analog at `AUDIO_TOKEN_ID`). | |
| * Generation is UNIFORM discrete diffusion: canvas starts as uniform-random | |
| tokens; each step accepts low-entropy predictions and *renoises the rest to | |
| fresh uniform-random tokens*. There is NO absorbing <mask> state. | |
| So training = denoising score matching against uniform corruption: take the clean | |
| transcript canvas x0, replace a fraction γ of positions with uniform-random | |
| tokens to get x_t, and train the model to predict x0 at the corrupted positions, | |
| conditioned on the audio (in the encoder cache). | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from audio import AudioProjector | |
@dataclass
class AudioDiffusionConfig:
    """Static configuration for the audio-conditioned DiffusionGemma wrapper.

    Only ``model_dir`` is required; every other field has a default.
    """

    # Path or id handed to ``from_pretrained`` for the DiffusionGemma base model.
    model_dir: str
    # Whisper checkpoint whose encoder is loaded as the audio feature extractor.
    whisper_id: str = "openai/whisper-small"
    # Feature width of the Whisper encoder output (projector ``in_dim``).
    whisper_dim: int = 768
    # Output width of the audio projector (projector ``d_model``).
    d_model: int = 2816
    # Vocabulary size; upper bound for the uniform-random corruption tokens.
    vocab_size: int = 262144
    # Begin-of-audio token id (not referenced by the code visible in this file).
    boa_token_id: int = 256000
    # Placeholder token id whose embedding slots are overwritten with audio embeds.
    audio_token_id: int = 258881
    # End-of-audio token id (not referenced by the code visible in this file).
    eoa_token_id: int = 258883
    # Soft-cap applied to logits as ``tanh(logits / cap) * cap``.
    final_logit_softcapping: float = 30.0
    # Passed to the audio projector as ``subsample_factor``.
    subsample_factor: int = 8
    # Passed to the audio projector as ``hidden``.
    proj_hidden: int = 1280
class AudioDiffusionGemma(nn.Module):
    """DiffusionGemma wrapper that conditions text denoising on audio.

    Audio goes through a (frozen) Whisper encoder and a trainable projector; the
    projected frames overwrite the embedding slots of ``cfg.audio_token_id``
    placeholders in the prompt before the encoder builds its KV cache. The
    decoder then denoises a token canvas against that cache.
    """

    def __init__(self, base, cfg: "AudioDiffusionConfig", whisper=None):
        """Wire up the projector and grab handles to the base model's parts.

        Args:
            base: the loaded base model (expected: DiffusionGemmaForBlockDiffusion).
            cfg: wrapper configuration.
            whisper: Whisper encoder module used by ``_project``; must be set
                before any method that projects audio is called.
        """
        super().__init__()
        self.base = base  # DiffusionGemmaForBlockDiffusion
        self.whisper = whisper  # frozen Whisper encoder (feature extractor)
        self.cfg = cfg
        self.projector = AudioProjector(
            d_model=cfg.d_model,
            in_dim=cfg.whisper_dim,
            hidden=cfg.proj_hidden,
            subsample_factor=cfg.subsample_factor,
        )
        # Sub-module handles (avoid the vision tower entirely).
        self.text_encoder = base.model.encoder.language_model  # DiffusionGemmaEncoderTextModel
        self.decoder = base.model.decoder  # DiffusionGemmaDecoderModel
        self.embed = base.get_input_embeddings()  # scaled word embedding (tied)
        self.lm_head = base.lm_head

    # ---- construction ----
    @classmethod
    def from_pretrained(cls, cfg: "AudioDiffusionConfig", dtype=torch.bfloat16, device="cuda"):
        """Load the base model and Whisper encoder, then build the wrapper.

        Args:
            cfg: wrapper configuration (``model_dir`` and ``whisper_id`` are read here).
            dtype: dtype used to load the base model and the Whisper encoder.
            device: device for the base model, the Whisper encoder and the projector.

        Returns:
            A new instance of ``cls`` with a frozen, eval-mode Whisper encoder
            and an fp32 projector.
        """
        import transformers
        from transformers import AutoConfig, WhisperModel

        hfcfg = AutoConfig.from_pretrained(cfg.model_dir)
        # Resolve the concrete model class from the checkpoint's declared architecture.
        ModelClass = getattr(transformers, hfcfg.architectures[0])
        base = ModelClass.from_pretrained(cfg.model_dir, dtype=dtype, device_map=device)

        # Frozen Whisper encoder (acoustic feature extractor; NOT a decoder).
        whisper = WhisperModel.from_pretrained(cfg.whisper_id, dtype=dtype).encoder
        whisper = whisper.to(device).eval()
        for p in whisper.parameters():
            p.requires_grad_(False)

        model = cls(base, cfg, whisper=whisper)
        # Keep the trainable projector in fp32 for stable AdamW; backbone stays bf16.
        model.projector = model.projector.to(device=device, dtype=torch.float32)
        return model

    def freeze_backbone(self):
        """Stage-1: only the audio projector trains."""
        for p in self.base.parameters():
            p.requires_grad_(False)
        for p in self.projector.parameters():
            p.requires_grad_(True)

    def apply_lora(self, r: int = 16, alpha: int = 32, dropout: float = 0.05):
        """LoRA on ENCODER + DECODER attention (so the encoder learns to integrate
        audio into the cache AND the decoder learns to attend to it), plus the
        trainable projector. Experts/vision stay frozen.

        Args:
            r: LoRA rank.
            alpha: LoRA alpha.
            dropout: LoRA dropout.

        Returns:
            ``self``, so the call can be chained.
        """
        from peft import LoraConfig, get_peft_model

        target = r".*(decoder|encoder\.language_model)\.layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)$"
        lcfg = LoraConfig(r=r, lora_alpha=alpha, lora_dropout=dropout,
                          target_modules=target, bias="none")
        self.base = get_peft_model(self.base, lcfg)  # freezes base, enables LoRA; submodule objects unchanged
        for p in self.projector.parameters():
            p.requires_grad_(True)
        return self

    def lora_state_dict(self):
        """Return the PEFT adapter state dict of the (LoRA-wrapped) base model."""
        from peft import get_peft_model_state_dict

        return get_peft_model_state_dict(self.base)

    def trainable_parameters(self):
        """Return every parameter of this module that currently requires grad."""
        return [p for p in self.parameters() if p.requires_grad]

    # ---- audio injection into the encoder ----
    def _project(self, mel):
        """Frozen Whisper -> trainable projector -> [B, T_a, H] audio embeds."""
        wdtype = next(self.whisper.parameters()).dtype
        with torch.no_grad():  # encoder frozen -> no grad/activations
            wfeat = self.whisper(mel.to(wdtype)).last_hidden_state  # [B, 1500, 768]
        return self.projector(wfeat.float())  # [B, T_a, H] (fp32 projector)

    def _scatter_audio(self, input_ids, audio):
        """Embed token ids, then overwrite AUDIO_TOKEN_ID slots with audio embeds.

        Audio frames are flattened over the batch and consumed in order by the
        placeholder positions, so the two counts must match exactly.

        Raises:
            ValueError: if the number of placeholder tokens in ``input_ids``
                differs from the number of audio frames.
        """
        hidden = audio.shape[-1]
        flat = audio.reshape(-1, hidden)
        embeds = self.embed(input_ids)
        audio_pos = input_ids == self.cfg.audio_token_id
        n_slots = int(audio_pos.sum())  # computed once (forces a single host sync)
        if n_slots != flat.shape[0]:
            # Raised explicitly rather than asserted: an assert is stripped under
            # ``python -O`` and a mismatch would then scatter the wrong frames.
            raise ValueError(
                f"audio placeholders ({n_slots}) != audio embeds ({flat.shape[0]})"
            )
        return embeds.masked_scatter(audio_pos.unsqueeze(-1), flat.to(embeds.dtype))

    def _encode(self, prompt_ids, prompt_mask, audio):
        """Scatter audio into the prompt and run the encoder -> KV cache."""
        embeds = self._scatter_audio(prompt_ids, audio)
        enc = self.text_encoder(inputs_embeds=embeds, attention_mask=prompt_mask)
        return enc.past_key_values

    def _decode_logits(self, cache, canvas_ids, decoder_attention_mask, self_conditioning_logits=None):
        """Run the decoder on a canvas against the encoder cache; return soft-capped fp32 logits."""
        dec = self.decoder(
            decoder_input_ids=canvas_ids,
            past_key_values=cache,
            decoder_attention_mask=decoder_attention_mask,
            self_conditioning_logits=self_conditioning_logits,
        )
        return self._softcap(self.lm_head(dec.last_hidden_state))

    def _softcap(self, logits):
        """Cast to fp32 and apply ``tanh(logits / cap) * cap`` with cap = cfg.final_logit_softcapping."""
        sc = self.cfg.final_logit_softcapping
        return torch.tanh(logits.float() / sc) * sc

    # ---- training step ----
    def forward(self, batch, gamma_min: float = 0.0, high_gamma_frac: float = 0.0,
                ar_weight: float = 0.0, ctc_weight: float = 0.0, weight_by_gamma: bool = False):
        """Joint training step: diffusion loss + optional autoregressive aux loss.

        The diffusion objective is a weak teacher for audio conditioning (the
        frozen LM prior out-competes the diluted, whole-canvas gradient). The AR
        aux loss teacher-forces the transcript through DiffusionGemma's OWN AR
        encoder (strong per-token audio→text gradient); because the encoder shares
        weights with the diffusion decoder + the same projector, that grounding
        transfers to parallel denoising. `ar_weight` mixes them.

        Args:
            batch: dict with keys ``prompt_ids``, ``prompt_mask``, ``mel``,
                ``canvas`` and ``canvas_loss_mask``; when ``ctc_weight > 0`` also
                ``ctc_targets``, ``audio_real_lengths`` and ``ctc_target_lengths``.
            gamma_min: lower bound of the per-sample corruption rate (upper bound is 1).
            high_gamma_frac: probability of forcing a sample's corruption rate to 1.
            ar_weight: weight of the AR auxiliary loss (branch skipped when 0).
            ctc_weight: weight of the CTC auxiliary loss (branch skipped when 0).
            weight_by_gamma: divide the per-token diffusion CE by the corruption rate.

        Returns:
            dict with ``loss`` (carries grad) and detached ``token_acc``,
            ``diff_loss``, ``ar_loss``, ``ctc_loss``.
        """
        prompt_ids = batch["prompt_ids"]
        prompt_mask = batch["prompt_mask"]
        mel = batch["mel"]
        x0 = batch["canvas"]  # [B, L] clean target (tokens, EOS, PAD)
        loss_mask = batch["canvas_loss_mask"].bool()
        B, L = x0.shape
        P = prompt_ids.shape[1]
        device = x0.device
        pad = 0

        audio = self._project(mel)  # [B, T_a, H] — shared by both branches

        # --- diffusion branch (uniform corruption q(x_t|x0)) ---
        gamma = torch.empty(B, 1, device=device).uniform_(gamma_min, 1.0)
        if high_gamma_frac > 0:
            force_full = torch.rand(B, 1, device=device) < high_gamma_frac
            gamma = torch.where(force_full, torch.ones_like(gamma), gamma)
        corrupt = torch.rand(B, L, device=device) < gamma
        rand_tok = torch.randint(0, self.cfg.vocab_size, (B, L), device=device)
        x_t = torch.where(corrupt, rand_tok, x0)

        cache = self._encode(prompt_ids, prompt_mask, audio)
        # Decoder mask covers the prompt (cache) positions followed by the full canvas.
        dec_mask = torch.cat(
            [prompt_mask, torch.ones(B, L, device=device, dtype=prompt_mask.dtype)], dim=1)
        logits = self._decode_logits(cache, x_t, dec_mask)  # [B, L, V]

        train_pos = corrupt & loss_mask
        if train_pos.sum() == 0:
            # Nothing was corrupted inside the loss mask: fall back to all masked-in positions.
            train_pos = loss_mask
        n_train = train_pos.float().sum().clamp_min(1.0)
        ce = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), x0.reshape(-1),
                             reduction="none").reshape(B, L)
        if weight_by_gamma:
            ce = ce / gamma.clamp_min(1e-3)
        diff_loss = (ce * train_pos.float()).sum() / n_train

        # --- AR auxiliary branch (teacher-forced transcript through the AR encoder) ---
        ar_loss = torch.zeros((), device=device)
        if ar_weight > 0:
            ar_ids = torch.cat([prompt_ids, x0], dim=1)  # [B, P+L]
            ar_real = (x0 != pad)  # transcript+EOS
            ar_attn = torch.cat([prompt_mask, ar_real.to(prompt_mask.dtype)], dim=1)
            ar_embeds = self._scatter_audio(ar_ids, audio)
            ar_hidden = self.text_encoder(inputs_embeds=ar_embeds, attention_mask=ar_attn).last_hidden_state
            # Hidden state at position i predicts token i+1, hence the one-step-left slice.
            ar_logits = self._softcap(self.lm_head(ar_hidden[:, P - 1:P + L - 1, :]))  # predicts x0
            ar_ce = F.cross_entropy(ar_logits.reshape(-1, ar_logits.shape[-1]), x0.reshape(-1),
                                    reduction="none").reshape(B, L)
            ar_loss = (ar_ce * ar_real.float()).sum() / ar_real.float().sum().clamp_min(1.0)

        # --- CTC auxiliary branch (direct projector supervision; breaks the
        #     chicken-and-egg by making audio embeds transcript-predictive in the
        #     LLM token space via the frozen lm_head, independent of attention) ---
        ctc_loss = torch.zeros((), device=device)
        if ctc_weight > 0:
            ctc_logits = self.lm_head(audio.to(self.lm_head.weight.dtype))  # [B, T_a, V]
            log_probs = ctc_logits.float().log_softmax(-1).transpose(0, 1)  # [T_a, B, V]
            ctc_loss = F.ctc_loss(
                log_probs, batch["ctc_targets"],
                batch["audio_real_lengths"], batch["ctc_target_lengths"],
                blank=0, zero_infinity=True)

        loss = diff_loss + ar_weight * ar_loss + ctc_weight * ctc_loss
        with torch.no_grad():
            acc = ((logits.argmax(-1) == x0) & train_pos).float().sum() / n_train
        return {"loss": loss, "token_acc": acc.detach(),
                "diff_loss": diff_loss.detach(), "ar_loss": ar_loss.detach(),
                "ctc_loss": ctc_loss.detach()}

    # ---- diagnostic: CTC greedy decode straight from the projector (is audio grounded?) ----
    @torch.no_grad()
    def ctc_greedy(self, prompt_ids, prompt_mask, mel, audio_real_lengths):
        """Greedy CTC decode of the projected audio through ``lm_head``.

        ``prompt_ids`` and ``prompt_mask`` are accepted but not used. Runs
        without autograd since only token ids are returned.

        Returns:
            One list of token ids per batch item: the per-frame argmax truncated
            to ``audio_real_lengths[b]``, with consecutive repeats collapsed and
            the blank id 0 dropped.
        """
        audio = self._project(mel)  # [B, T_a, H]
        ids = self.lm_head(audio.to(self.lm_head.weight.dtype)).argmax(-1)  # [B, T_a]
        out = []
        for b in range(ids.shape[0]):
            seq = ids[b, : int(audio_real_lengths[b])].tolist()
            collapsed, prev = [], None
            for t in seq:  # collapse repeats + drop blank(0)
                if t != prev and t != 0:
                    collapsed.append(t)
                prev = t
            out.append(collapsed)
        return out

    # ---- inference: audio-conditioned parallel denoising ----
    def _mask_mapping(self, prompt_mask, cache, canvas_len, device):
        """Build the decoder attention mask(s) for a prompt mask plus an all-ones canvas."""
        dec_mask = torch.cat(
            [prompt_mask, torch.ones(prompt_mask.shape[0], canvas_len, device=device, dtype=prompt_mask.dtype)], dim=1)
        return self.decoder.create_diffusion_decoder_attention_mask(
            config=self.base.config.text_config,
            # Placeholder tensor: only the batch and canvas-length dims are given real sizes.
            inputs_embeds=torch.empty(prompt_mask.shape[0], canvas_len, 1, device=device),
            past_key_values=cache, decoder_attention_mask=dec_mask)

    @torch.no_grad()
    def generate(self, prompt_ids, prompt_mask, mel, *,
                 canvas_len: int = 256, max_steps: int = 48, entropy_bound: float = 0.1,
                 t_min: float = 0.4, t_max: float = 0.8,
                 confidence_threshold: float = 0.005, stability_threshold: int = 1,
                 guidance_weight: float = 1.0):
        """DiffusionGemma's uniform-diffusion sampler, conditioned on audio, with
        optional classifier-free guidance (CFG).

        guidance_weight w>1 amplifies the audio's effect: at each step we combine a
        conditional pass (attends to audio KV) and an unconditional pass (same cache
        but the decoder mask hides the audio KV -> pure language prior):
            guided = uncond + w * (cond - uncond)
        This fixes the 'conditioning ignored' failure WITHOUT retraining — the
        unconditional branch is exactly the prior the model already produces.

        Runs without autograd: the previous step's logits are fed back as
        self-conditioning, so tracking gradients would retain every step's graph.

        Args:
            prompt_ids, prompt_mask: prompt tokens (with audio placeholders) and mask.
            mel: audio features fed to the Whisper encoder.
            canvas_len: number of canvas positions to denoise.
            max_steps: maximum number of denoising steps.
            entropy_bound: passed to ``EntropyBoundSamplerConfig``.
            t_min, t_max: passed to ``LinearTemperatureScheduleLogitsProcessor``.
            confidence_threshold, stability_threshold: passed to
                ``StableAndConfidentStoppingCriteria``.
            guidance_weight: CFG weight; exactly 1.0 skips the unconditional pass.

        Returns (argmax_canvas [B, canvas_len], steps_used).
        """
        from transformers.models.diffusion_gemma.generation_diffusion_gemma import (
            EntropyBoundSampler, EntropyBoundSamplerConfig,
            LinearTemperatureScheduleLogitsProcessor, StableAndConfidentStoppingCriteria,
        )

        device = prompt_ids.device
        B = prompt_ids.shape[0]
        cache = self._encode(prompt_ids, prompt_mask, self._project(mel))
        cfg_on = guidance_weight != 1.0

        sampler = EntropyBoundSampler(
            EntropyBoundSamplerConfig(entropy_bound=entropy_bound), canvas_len, self.cfg.vocab_size, max_steps)
        temp = LinearTemperatureScheduleLogitsProcessor(t_min, t_max, max_steps)
        stopping = StableAndConfidentStoppingCriteria(stability_threshold, confidence_threshold)

        mm_cond = self._mask_mapping(prompt_mask, cache, canvas_len, device)
        if cfg_on:  # unconditional = hide audio KV from the decoder
            pmask_uncond = prompt_mask.clone()
            pmask_uncond[prompt_ids == self.cfg.audio_token_id] = 0
            mm_uncond = self._mask_mapping(pmask_uncond, cache, canvas_len, device)

        current = sampler.initialize_canvas(B, device)
        argmax_canvas = current.clone()
        self_cond = None
        steps_used = 0
        for step in range(max_steps):
            cond = self._decode_logits(cache, current, mm_cond, self_conditioning_logits=self_cond)
            if cfg_on:
                uncond = self._decode_logits(cache, current, mm_uncond, self_conditioning_logits=self_cond)
                logits = uncond + guidance_weight * (cond - uncond)
            else:
                logits = cond
            processed = temp(current, logits, cur_step=step)
            probs = torch.softmax(processed, dim=-1, dtype=torch.float32)
            denoiser = torch.multinomial(probs.view(-1, probs.shape[-1]), 1).view(B, canvas_len)
            argmax_canvas = processed.argmax(-1)
            accepted = sampler.accept_canvas(current, denoiser, processed, step)
            current = sampler.renoise_canvas(accepted, step)
            self_cond = cond  # conditional (unguided) logits feed the next step's self-conditioning
            steps_used = step + 1
            if bool(stopping(argmax_canvas, processed).all()):
                break
        return argmax_canvas, steps_used