"""Gradio demo for TinyVoice: Georgian ASR with a Zipformer encoder + CTC head.

Restores the newest Orbax checkpoint from ./checkpoints, pre-compiles the
jitted encoder for several audio-duration buckets, and serves a Gradio
Blocks UI that transcribes .flac files found in ./test_audio.
"""

import os

# Must be set before JAX initializes a backend. The original script set this
# *after* `import jax`, which only works because JAX reads the variable
# lazily at first backend use — setting it before the import is the safe,
# documented ordering.
os.environ["JAX_PLATFORMS"] = "cpu"

import time
from pathlib import Path

import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

from conformer.config import ConformerConfig, FeaturizerConfig
from conformer.model import ZipformerEncoder
from conformer.tokenizer import Tokenizer


def load_model():
    """Build the Zipformer encoder and restore the latest checkpoint.

    Returns:
        Tuple of (model, tokenizer, feat_config) ready for inference.

    Raises:
        ValueError: if ./checkpoints contains no saved steps.
    """
    feat_config = FeaturizerConfig()
    conf_config = ConformerConfig()

    tokenizer = Tokenizer.load_tokenizer(Path("./checkpoints/tokenizer.pkl"))
    # Some tokenizer versions expose vocab_size; fall back to the char map size.
    token_count = (
        tokenizer.vocab_size
        if hasattr(tokenizer, "vocab_size")
        else len(tokenizer.id_to_char)
    )

    model = ZipformerEncoder(
        token_count=token_count,
        num_layers=conf_config.num_encoder_layers,
        d_model=conf_config.encoder_dim,
        num_head=conf_config.num_attention_heads,
        dropout=0.0,  # inference only — no dropout
        feed_forward_expansion_factor=conf_config.feed_forward_expansion_factor,
        d_input=feat_config.n_mels,
        sample_rate=feat_config.sampling_rate,
        n_fft=feat_config.n_fft,
        n_window_size=feat_config.win_length,
        n_window_stride=feat_config.hop_length,
        dtype=jnp.float32,
        rngs=nnx.Rngs(0),
    )

    # Orbax requires an absolute path for the checkpoint directory.
    checkpoint_dir = os.path.abspath("./checkpoints")
    mngr = ocp.CheckpointManager(
        checkpoint_dir, options=ocp.CheckpointManagerOptions()
    )
    try:
        latest_step = mngr.latest_step()
        if latest_step is None:
            raise ValueError("No checkpoints found.")
        print(f"Restoring model from step {latest_step}...")
        restored = mngr.restore(
            latest_step,
            args=ocp.args.Composite(
                model=ocp.args.StandardRestore(nnx.state(model)),
            ),
        )
        nnx.update(model, restored.model)
    finally:
        # Always release the manager, even if restore fails (original leaked it).
        mngr.close()
    return model, tokenizer, feat_config


print("Loading model...")
MODEL, TOKENIZER, FEAT_CONFIG = load_model()
# Inference-time featurizer settings: no SpecAugment, keep normalization.
# NOTE: "spectogram" (sic) matches the attribute name declared by the model.
MODEL.mel_spectogram.spec_augment = False
MODEL.mel_spectogram.normalize = True


def trim_silence(audio, threshold=0.01):
    """Strip leading/trailing samples whose absolute amplitude <= threshold.

    Fixes vs. the original: all-silent input is returned unchanged instead of
    raising IndexError, the threshold mask is computed once, and the last loud
    sample is kept (``end + 1`` — the original slice dropped it).
    """
    loud = np.where(np.abs(audio) > threshold)[0]
    if loud.size == 0:  # nothing above threshold → nothing to anchor the trim
        return audio
    return audio[loud[0] : loud[-1] + 1]


def ctc_decode(ids, tokenizer):
    """Greedy CTC decode: collapse repeats, drop blank/padding, map to chars.

    Unknown ids render as ``[<id>]`` so decoding problems stay visible.
    """
    collapsed_ids = []
    prev = -1
    for _id in ids:
        _id = int(_id)
        if _id != prev:
            # Keep only real symbols: skip the CTC blank and (if the tokenizer
            # defines one) the padding id.
            if _id != tokenizer.blank_id and _id != getattr(
                tokenizer, "padding_id", -1
            ):
                collapsed_ids.append(_id)
        prev = _id
    return "".join([tokenizer.id_to_char.get(_id, f"[{_id}]") for _id in collapsed_ids])


@nnx.jit
def jit_inference(model: ZipformerEncoder, audio: jnp.ndarray, lengths: jnp.ndarray):
    """Jitted forward pass; returns (logits, output_lengths)."""
    return model(audio, training=False, inputs_lengths=lengths)


def transcribe(audio_input):
    """Transcribe a (sample_rate, samples) tuple from Gradio.

    Returns:
        (text, elapsed) — decoded transcription and inference wall time.
    """
    if audio_input is None:
        return "Please provide an audio file.", ""
    sr, audio = audio_input

    # Gradio hands over raw PCM; integer formats are scaled to [-1, 1).
    if audio.dtype.kind in "iu":
        max_val = float(2 ** (8 * audio.itemsize - 1))
        audio = audio.astype(np.float32) / max_val
    else:
        audio = audio.astype(np.float32)

    # Fix: Gradio may deliver stereo as (samples, channels); the pipeline
    # below is strictly mono, so average the channels down.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Resample to the model's expected rate.
    if sr != FEAT_CONFIG.sampling_rate:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=FEAT_CONFIG.sampling_rate)
        sr = FEAT_CONFIG.sampling_rate

    # === FIXES FOR MIC DOMAIN SHIFT ===
    audio = audio - np.mean(audio)  # DC removal

    # Pre-emphasis (this is the #1 thing missing for real mics vs Common Voice)
    if len(audio) > 1:
        audio = np.append(audio[0], audio[1:] - 0.97 * audio[:-1])

    # RMS normalization toward a fixed loudness target, then clip headroom.
    rms = np.sqrt(np.mean(audio**2) + 1e-8)
    target_rms = 0.18
    if rms > 0:  # always true given the +1e-8 floor; kept as a cheap guard
        audio = audio * (target_rms / rms)
    audio = np.clip(audio, -0.95, 0.95)

    # Strong silence trim: keep ~1000 samples of context around loud region.
    energy = np.abs(audio)
    if len(energy) > 3000:
        threshold = np.max(energy) * 0.012
        non_silent = np.where(energy > threshold)[0]
        if len(non_silent) > 400:
            start = max(0, non_silent[0] - 1000)
            end = min(len(audio), non_silent[-1] + 1000)
            audio = audio[start:end]

    print(
        f"[DEBUG] RMS after norm: {np.sqrt(np.mean(audio**2)):.4f} | Duration: {len(audio) / sr:.2f}s"
    )

    # Pad to ~115% of the trimmed duration (capped at 12 s of 16 kHz audio)
    # so jit compilation buckets by a few padded lengths, not every length.
    duration_samples = len(audio)
    target_len = min(int(np.ceil(duration_samples / sr * 1.15) * sr), 16000 * 12)
    padded = np.zeros(target_len, dtype=np.float32)
    padded[:duration_samples] = audio

    input_audio = padded[np.newaxis, :]
    input_lengths = jnp.array([duration_samples], dtype=jnp.int32)

    start = time.time()
    logits, output_lengths = jit_inference(MODEL, input_audio, input_lengths)
    logits.block_until_ready()

    seq_len = int(output_lengths[0])
    predicted_ids = jnp.argmax(logits[0, :seq_len], axis=-1)

    # === IMPORTANT DEBUG ===
    blank_count = jnp.sum(predicted_ids == TOKENIZER.blank_id)
    # max(..., 1) guards the zero-frame edge case (original divided by zero).
    blank_ratio = float(blank_count) / max(len(predicted_ids), 1)
    non_blank = len(predicted_ids) - int(blank_count)
    print(
        f"[DEBUG] Frames after subsampling: {seq_len} | Non-blank tokens: {non_blank} ({blank_ratio:.1%} blank)"
    )

    text = ctc_decode(predicted_ids, TOKENIZER)
    return text or "", f"{time.time() - start:.3f}s"


# Discover files from test_audio
TEST_AUDIO_DIR = Path("test_audio")
AUDIO_FILES = sorted(TEST_AUDIO_DIR.glob("*.flac"))
FILE_OPTIONS = [f.name for f in AUDIO_FILES]
FILE_MAP = {f.name: f for f in AUDIO_FILES}


def transcribe_by_filename(filename):
    """Look up a known test file by name and run the transcription pipeline."""
    if not filename or filename not in FILE_MAP:
        return "Please select a valid file.", "0.000s"
    filepath = FILE_MAP[filename]
    # Load audio at model's sampling rate
    audio, sr = librosa.load(filepath, sr=FEAT_CONFIG.sampling_rate)
    # Reuse the core transcription logic
    return transcribe((sr, audio))


def update_preview(filename):
    """Return the path for the audio preview widget, or None if unknown."""
    if not filename or filename not in FILE_MAP:
        return None
    return str(FILE_MAP[filename])


# Pre-compile the jitted model for the duration buckets the padding logic
# produces, so first user requests don't pay XLA compile time.
# NOTE(review): the 16000 here assumes FEAT_CONFIG.sampling_rate == 16 kHz,
# as does the 12 s cap in transcribe() — confirm if the config ever changes.
print("Pre-compiling...")
for sec in [3, 4, 5, 6, 8, 12]:
    dummy_len = sec * 16000
    dummy_audio = np.zeros((1, dummy_len), dtype=np.float32)
    dummy_lengths = jnp.array([int(dummy_len * 0.8)], dtype=jnp.int32)  # 80% filled
    logits, _ = jit_inference(MODEL, dummy_audio, dummy_lengths)
    logits.block_until_ready()
    print(f" Compiled {sec}s bucket")


# Gradio Interface with Blocks for premium look
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    title="TinyVoice: Georgian ASR",
) as iface:
    gr.Markdown(
        """
        # 🎙️ TinyVoice: Georgian ASR Explorer
        """
    )
    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("### 🛠️ Audio Selection")
            audio_dropdown = gr.Dropdown(
                choices=FILE_OPTIONS,
                label="Search Audio Files",
                info=f"Listing {len(FILE_OPTIONS)} files from test_audio/",
                value=FILE_OPTIONS[0] if FILE_OPTIONS else None,
                filterable=True,
            )
            audio_preview = gr.Audio(
                value=str(AUDIO_FILES[0]) if AUDIO_FILES else None,
                label="Audio Preview",
                interactive=False,
            )
            transcribe_btn = gr.Button(
                "🚀 Transcribe Now", variant="primary", size="lg"
            )
        with gr.Column(variant="panel"):
            gr.Markdown("### 📝 Results")
            output_text = gr.Textbox(
                label="Transcription",
                lines=8,
                placeholder="Transcription will appear here...",
            )
            output_time = gr.Textbox(label="Inference Time")

    # Event Handlers
    audio_dropdown.change(
        fn=update_preview, inputs=audio_dropdown, outputs=audio_preview
    )
    transcribe_btn.click(
        fn=transcribe_by_filename,
        inputs=audio_dropdown,
        outputs=[output_text, output_time],
    )


if __name__ == "__main__":
    # Fix: Blocks.launch() takes no `theme` argument (the original passed one,
    # a TypeError at runtime) — the theme is already set on gr.Blocks above.
    iface.launch()