# Hugging Face Spaces app: TinyVoice — Georgian ASR demo (Gradio + JAX/Flax-NNX)
# Force CPU execution: JAX_PLATFORMS must be exported BEFORE JAX initializes
# its backend, so set it ahead of any `jax` import.  (The original set it
# after `import jax`, which is not guaranteed to take effect.)
import os

os.environ["JAX_PLATFORMS"] = "cpu"

import time
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
import librosa
import gradio as gr
from flax import nnx
import orbax.checkpoint as ocp

from conformer.tokenizer import Tokenizer
from conformer.config import ConformerConfig, FeaturizerConfig
from conformer.model import ZipformerEncoder
def load_model() -> tuple:
    """Build the ZipformerEncoder and restore its weights from the latest
    Orbax checkpoint under ./checkpoints.

    Returns:
        (model, tokenizer, feat_config): the restored encoder, the
        character tokenizer, and the featurizer configuration.

    Raises:
        ValueError: if the checkpoint directory contains no saved step.
    """
    feat_config = FeaturizerConfig()
    conf_config = ConformerConfig()
    tokenizer = Tokenizer.load_tokenizer(Path("./checkpoints/tokenizer.pkl"))
    # Vocabulary size: prefer the explicit attribute, fall back to the
    # id->char table length for tokenizer pickles that lack `vocab_size`.
    token_count = (
        tokenizer.vocab_size
        if hasattr(tokenizer, "vocab_size")
        else len(tokenizer.id_to_char)
    )
    model = ZipformerEncoder(
        token_count=token_count,
        num_layers=conf_config.num_encoder_layers,
        d_model=conf_config.encoder_dim,
        num_head=conf_config.num_attention_heads,
        dropout=0.0,  # inference-only instance: dropout disabled
        feed_forward_expansion_factor=conf_config.feed_forward_expansion_factor,
        d_input=feat_config.n_mels,
        sample_rate=feat_config.sampling_rate,
        n_fft=feat_config.n_fft,
        n_window_size=feat_config.win_length,
        n_window_stride=feat_config.hop_length,
        dtype=jnp.float32,
        rngs=nnx.Rngs(0),
    )
    # Orbax requires an absolute path for the checkpoint root.
    checkpoint_dir = os.path.abspath("./checkpoints")
    mngr = ocp.CheckpointManager(checkpoint_dir, options=ocp.CheckpointManagerOptions())
    latest_step = mngr.latest_step()
    if latest_step is None:
        raise ValueError("No checkpoints found.")
    print(f"Restoring model from step {latest_step}...")
    # Restore the saved parameters into the freshly initialized model's
    # state tree, then write them back into the live module.
    restored = mngr.restore(
        latest_step,
        args=ocp.args.Composite(
            model=ocp.args.StandardRestore(nnx.state(model)),
        ),
    )
    nnx.update(model, restored.model)
    mngr.close()
    return model, tokenizer, feat_config
| print("Loading model...") | |
| MODEL, TOKENIZER, FEAT_CONFIG = load_model() | |
| MODEL.mel_spectogram.spec_augment = False | |
| MODEL.mel_spectogram.normalize = True | |
def trim_silence(audio: np.ndarray, threshold: float = 0.01) -> np.ndarray:
    """Trim leading/trailing samples whose absolute amplitude is <= threshold.

    Args:
        audio: 1-D waveform.
        threshold: amplitude at or below which a sample counts as silence.

    Returns:
        The sub-array spanning the first through last sample louder than
        ``threshold``.  If no sample exceeds the threshold, the input is
        returned unchanged (the original raised IndexError in that case).
    """
    # Compute the voiced indices once (the original ran np.where twice).
    voiced = np.where(np.abs(audio) > threshold)[0]
    if voiced.size == 0:
        # All-silent (or empty) input: nothing to trim; avoid IndexError.
        return audio
    # +1 so the half-open slice keeps the final voiced sample (the original
    # `audio[start:end]` silently dropped it).
    return audio[voiced[0] : voiced[-1] + 1]
def ctc_decode(ids, tokenizer):
    """Greedy CTC collapse: merge consecutive repeats, drop blank/padding
    ids, and map the surviving ids through the tokenizer's id->char table.
    Unknown ids are rendered as "[<id>]"."""
    pad_id = getattr(tokenizer, "padding_id", -1)
    kept = []
    last_token = -1
    for raw in ids:
        token = int(raw)
        if token == last_token:
            continue  # repeated frame of the same symbol — collapse it
        last_token = token
        if token != tokenizer.blank_id and token != pad_id:
            kept.append(token)
    return "".join(tokenizer.id_to_char.get(t, f"[{t}]") for t in kept)
@nnx.jit
def jit_inference(model: ZipformerEncoder, audio: jnp.ndarray, lengths: jnp.ndarray):
    """Run the encoder forward pass under JIT compilation.

    The function name, the warm-up "Pre-compiling..." loop below, and the
    `block_until_ready()` calls all assume a jitted entry point, but the
    original carried no decorator — so every call ran eagerly and the
    warm-up was a no-op.  `nnx.jit` handles the NNX module argument.

    Args:
        model: the restored ZipformerEncoder.
        audio: (batch, samples) float waveform.
        lengths: (batch,) valid-sample counts.

    Returns:
        Whatever the encoder's __call__ returns; used here as
        (logits, output_lengths).
    """
    return model(audio, training=False, inputs_lengths=lengths)
def transcribe(audio_input):
    """Run ASR on a Gradio audio payload.

    Args:
        audio_input: ``(sample_rate, waveform)`` tuple as produced by
            ``gr.Audio`` (integer PCM or float samples; mono or
            channel-last stereo), or ``None``.

    Returns:
        ``(text, elapsed)``: the decoded transcription (``"<EMPTY>"`` when
        decoding yields nothing) and the inference wall-clock time string.
    """
    if audio_input is None:
        return "Please provide an audio file.", ""
    sr, audio = audio_input
    # Downmix multi-channel input to mono.  Gradio delivers stereo as
    # (samples, channels); the 1-D pipeline below would otherwise crash.
    if audio.ndim > 1:
        audio = audio.mean(axis=-1)
    # Integer PCM -> float32 in [-1, 1): divide by the sample width's full scale.
    if audio.dtype.kind in "iu":
        max_val = float(2 ** (8 * audio.itemsize - 1))
        audio = audio.astype(np.float32) / max_val
    else:
        audio = audio.astype(np.float32)
    # Resample to the model's training rate.
    if sr != FEAT_CONFIG.sampling_rate:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=FEAT_CONFIG.sampling_rate)
        sr = FEAT_CONFIG.sampling_rate
    # === FIXES FOR MIC DOMAIN SHIFT ===
    audio = audio - np.mean(audio)  # DC removal
    # Pre-emphasis (this is the #1 thing missing for real mics vs Common Voice)
    if len(audio) > 1:
        audio = np.append(audio[0], audio[1:] - 0.97 * audio[:-1])
    # RMS normalization to a fixed loudness target, then a soft clip.
    rms = np.sqrt(np.mean(audio**2) + 1e-8)
    target_rms = 0.18
    if rms > 0:
        audio = audio * (target_rms / rms)
    audio = np.clip(audio, -0.95, 0.95)
    # Strong silence trim: keep ~1000-sample margins around the voiced region.
    energy = np.abs(audio)
    if len(energy) > 3000:
        threshold = np.max(energy) * 0.012
        non_silent = np.where(energy > threshold)[0]
        if len(non_silent) > 400:
            start = max(0, non_silent[0] - 1000)
            end = min(len(audio), non_silent[-1] + 1000)
            audio = audio[start:end]
    print(
        f"[DEBUG] RMS after norm: {np.sqrt(np.mean(audio**2)):.4f} | Duration: {len(audio) / sr:.2f}s"
    )
    # Pad into a duration bucket: ~15% head-room, hard-capped at 12 s.
    duration_samples = len(audio)
    target_len = min(int(np.ceil(duration_samples / sr * 1.15) * sr), 16000 * 12)
    if duration_samples > target_len:
        # Clips longer than the cap previously crashed with a broadcast error
        # (`padded[:duration_samples] = audio` into a shorter buffer).
        audio = audio[:target_len]
        duration_samples = target_len
    padded = np.zeros(target_len, dtype=np.float32)
    padded[:duration_samples] = audio
    input_audio = padded[np.newaxis, :]
    input_lengths = jnp.array([duration_samples], dtype=jnp.int32)
    start = time.time()
    logits, output_lengths = jit_inference(MODEL, input_audio, input_lengths)
    logits.block_until_ready()
    seq_len = int(output_lengths[0])
    # Greedy (argmax) decoding over the valid frames only.
    predicted_ids = jnp.argmax(logits[0, :seq_len], axis=-1)
    # === IMPORTANT DEBUG ===
    blank_count = jnp.sum(predicted_ids == TOKENIZER.blank_id)
    # max(..., 1) guards the debug ratio against seq_len == 0.
    blank_ratio = float(blank_count) / max(len(predicted_ids), 1)
    non_blank = len(predicted_ids) - int(blank_count)
    print(
        f"[DEBUG] Frames after subsampling: {seq_len} | Non-blank tokens: {non_blank} ({blank_ratio:.1%} blank)"
    )
    text = ctc_decode(predicted_ids, TOKENIZER)
    return text or "<EMPTY>", f"{time.time() - start:.3f}s"
# Discover demo clips bundled with the app under test_audio/.
TEST_AUDIO_DIR = Path("test_audio")
# sorted() accepts any iterable — the intermediate list() was redundant.
AUDIO_FILES = sorted(TEST_AUDIO_DIR.glob("*.flac"))
FILE_OPTIONS = [f.name for f in AUDIO_FILES]   # dropdown labels
FILE_MAP = {f.name: f for f in AUDIO_FILES}    # label -> Path lookup
def transcribe_by_filename(filename):
    """Look up a bundled test clip by name and transcribe it.

    Returns the same (text, elapsed) pair as `transcribe`, or a prompt
    message with a zero time when the name is missing/unknown."""
    filepath = FILE_MAP.get(filename) if filename else None
    if filepath is None:
        return "Please select a valid file.", "0.000s"
    # Decode at the model's sampling rate so `transcribe` skips resampling.
    audio, sr = librosa.load(filepath, sr=FEAT_CONFIG.sampling_rate)
    # Delegate to the shared transcription pipeline.
    return transcribe((sr, audio))
def update_preview(filename):
    """Return the filesystem path of the selected clip for the preview
    player, or None when nothing valid is selected."""
    try:
        return str(FILE_MAP[filename]) if filename else None
    except KeyError:
        # Name not in the discovered-file map.
        return None
# Warm up the JIT cache for the duration buckets `transcribe` can produce,
# so the first user request does not pay compilation latency.
print("Pre-compiling...")
for sec in [3, 4, 5, 6, 8, 12]:
    # Use the model's configured sampling rate instead of a hard-coded
    # 16000 so the warm-up shapes match what `transcribe` actually feeds in.
    dummy_len = sec * FEAT_CONFIG.sampling_rate
    dummy_audio = np.zeros((1, dummy_len), dtype=np.float32)
    dummy_lengths = jnp.array([int(dummy_len * 0.8)], dtype=jnp.int32)  # 80% filled
    logits, _ = jit_inference(MODEL, dummy_audio, dummy_lengths)
    logits.block_until_ready()
    print(f" Compiled {sec}s bucket")
# Gradio Interface with Blocks for premium look
# NOTE(review): several UI strings below ("ποΈ", "π οΈ", "π") look like
# mojibake — presumably UTF-8 emoji decoded with the wrong encoding.
# They are runtime strings, so they are reproduced verbatim here; confirm
# the intended glyphs against the original source before changing them.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    title="TinyVoice: Georgian ASR",
) as iface:
    gr.Markdown(
        """
    # ποΈ TinyVoice: Georgian ASR Explorer
    """
    )
    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("### π οΈ Audio Selection")
            # Searchable dropdown over the files discovered in test_audio/.
            audio_dropdown = gr.Dropdown(
                choices=FILE_OPTIONS,
                label="Search Audio Files",
                info=f"Listing {len(FILE_OPTIONS)} files from test_audio/",
                value=FILE_OPTIONS[0] if FILE_OPTIONS else None,
                filterable=True,
            )
            # Read-only player showing the currently selected clip.
            audio_preview = gr.Audio(
                value=str(AUDIO_FILES[0]) if AUDIO_FILES else None,
                label="Audio Preview",
                interactive=False,
            )
            transcribe_btn = gr.Button(
                "π Transcribe Now", variant="primary", size="lg"
            )
        with gr.Column(variant="panel"):
            gr.Markdown("### π Results")
            output_text = gr.Textbox(
                label="Transcription",
                lines=8,
                placeholder="Transcription will appear here...",
            )
            output_time = gr.Textbox(label="Inference Time")
    # Event Handlers
    # Changing the dropdown swaps the preview player's source file.
    audio_dropdown.change(
        fn=update_preview, inputs=audio_dropdown, outputs=audio_preview
    )
    # Clicking runs ASR on the selected file and fills both result boxes.
    transcribe_btn.click(
        fn=transcribe_by_filename,
        inputs=audio_dropdown,
        outputs=[output_text, output_time],
    )
| if __name__ == "__main__": | |
| iface.launch(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")) | |