# app.py — Gradio Space "test-space" (author: ZurabDz)
# Commit ee68216: "Track audio files with Git LFS and add Gradio examples"
import os

# Force CPU execution BEFORE importing jax: JAX picks its platform during
# backend initialization, which can happen as early as import, so setting
# JAX_PLATFORMS after `import jax` (as this file previously did) is not
# guaranteed to take effect.
os.environ["JAX_PLATFORMS"] = "cpu"

import time
from pathlib import Path

import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

from conformer.config import ConformerConfig, FeaturizerConfig
from conformer.model import ZipformerEncoder
from conformer.tokenizer import Tokenizer
def load_model():
    """Build the Zipformer encoder and restore the latest Orbax checkpoint.

    Returns:
        (model, tokenizer, feat_config): the restored ZipformerEncoder, the
        tokenizer loaded from ./checkpoints/tokenizer.pkl, and the featurizer
        config describing the audio front-end.

    Raises:
        ValueError: if ./checkpoints contains no saved step.
    """
    feat_config = FeaturizerConfig()
    conf_config = ConformerConfig()
    tokenizer = Tokenizer.load_tokenizer(Path("./checkpoints/tokenizer.pkl"))
    # Older tokenizer pickles may predate the `vocab_size` attribute; fall
    # back to the size of the id->char table.
    token_count = (
        tokenizer.vocab_size
        if hasattr(tokenizer, "vocab_size")
        else len(tokenizer.id_to_char)
    )
    model = ZipformerEncoder(
        token_count=token_count,
        num_layers=conf_config.num_encoder_layers,
        d_model=conf_config.encoder_dim,
        num_head=conf_config.num_attention_heads,
        dropout=0.0,  # inference only: no dropout
        feed_forward_expansion_factor=conf_config.feed_forward_expansion_factor,
        d_input=feat_config.n_mels,
        sample_rate=feat_config.sampling_rate,
        n_fft=feat_config.n_fft,
        n_window_size=feat_config.win_length,
        n_window_stride=feat_config.hop_length,
        dtype=jnp.float32,
        rngs=nnx.Rngs(0),
    )
    checkpoint_dir = os.path.abspath("./checkpoints")
    mngr = ocp.CheckpointManager(checkpoint_dir, options=ocp.CheckpointManagerOptions())
    # BUGFIX: previously the manager was never closed when no checkpoint was
    # found (ValueError raised before close()) or when restore raised; the
    # try/finally guarantees release in every path.
    try:
        latest_step = mngr.latest_step()
        if latest_step is None:
            raise ValueError("No checkpoints found.")
        print(f"Restoring model from step {latest_step}...")
        restored = mngr.restore(
            latest_step,
            args=ocp.args.Composite(
                model=ocp.args.StandardRestore(nnx.state(model)),
            ),
        )
        # Graft the restored parameter state onto the freshly built model.
        nnx.update(model, restored.model)
    finally:
        mngr.close()
    return model, tokenizer, feat_config
# Load the model/tokenizer once at import time so every Gradio request
# shares the same restored weights.
print("Loading model...")
MODEL, TOKENIZER, FEAT_CONFIG = load_model()
# Inference-time featurizer switches: turn off SpecAugment (a training-only
# augmentation) and turn on feature normalization.
MODEL.mel_spectogram.spec_augment = False
MODEL.mel_spectogram.normalize = True
def trim_silence(audio, threshold=0.01):
energy = np.abs(audio)
start = np.where(energy > threshold)[0][0]
end = np.where(energy > threshold)[0][-1]
return audio[start:end]
def ctc_decode(ids, tokenizer):
    """Greedy CTC decode: collapse repeats, drop blank/padding, map to text.

    Args:
        ids: iterable of (possibly array-typed) token ids, one per frame.
        tokenizer: object exposing `blank_id`, `id_to_char`, and optionally
            `padding_id`.

    Returns:
        The decoded string; unknown ids render as "[<id>]".
    """
    pad_id = getattr(tokenizer, "padding_id", -1)
    pieces = []
    previous = -1
    for raw in ids:
        token = int(raw)
        # Emit only on a transition to a new, non-special token.
        if token != previous and token != tokenizer.blank_id and token != pad_id:
            pieces.append(tokenizer.id_to_char.get(token, f"[{token}]"))
        previous = token
    return "".join(pieces)
@nnx.jit
def jit_inference(model: ZipformerEncoder, audio: jnp.ndarray, lengths: jnp.ndarray):
    # JIT-compiled forward pass. XLA compiles one graph per distinct padded
    # audio length, which is why callers bucket their input lengths.
    return model(audio, training=False, inputs_lengths=lengths)
def transcribe(audio_input):
    """Transcribe a Gradio audio payload with the global model.

    Args:
        audio_input: `(sample_rate, waveform)` tuple as produced by
            `gr.Audio` (waveform is a 1-D numpy array, int PCM or float),
            or None.

    Returns:
        (text, elapsed): the decoded transcription (or "<EMPTY>" /
        a help message) and the inference wall time as a formatted string.
    """
    if audio_input is None:
        return "Please provide an audio file.", ""
    sr, audio = audio_input
    # Gradio may deliver int16/int32 PCM; scale to [-1, 1) floats.
    if audio.dtype.kind in "iu":
        max_val = float(2 ** (8 * audio.itemsize - 1))
        audio = audio.astype(np.float32) / max_val
    else:
        audio = audio.astype(np.float32)
    # Resample to the model's expected rate.
    if sr != FEAT_CONFIG.sampling_rate:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=FEAT_CONFIG.sampling_rate)
        sr = FEAT_CONFIG.sampling_rate
    # === FIXES FOR MIC DOMAIN SHIFT ===
    audio = audio - np.mean(audio)  # DC removal
    # Pre-emphasis (this is the #1 thing missing for real mics vs Common Voice)
    if len(audio) > 1:
        audio = np.append(audio[0], audio[1:] - 0.97 * audio[:-1])
    # RMS normalization
    rms = np.sqrt(np.mean(audio**2) + 1e-8)
    target_rms = 0.18
    if rms > 0:
        audio = audio * (target_rms / rms)
    audio = np.clip(audio, -0.95, 0.95)
    # Strong silence trim, keeping ~1000 samples of context on either side.
    energy = np.abs(audio)
    if len(energy) > 3000:
        threshold = np.max(energy) * 0.012
        non_silent = np.where(energy > threshold)[0]
        if len(non_silent) > 400:
            trim_start = max(0, non_silent[0] - 1000)
            trim_end = min(len(audio), non_silent[-1] + 1000)
            audio = audio[trim_start:trim_end]
    print(
        f"[DEBUG] RMS after norm: {np.sqrt(np.mean(audio**2)):.4f} | Duration: {len(audio) / sr:.2f}s"
    )
    # Pad to a bucketed length (15% headroom) so the jitted graph is reused,
    # capped at the 12-second budget. `sr * 12` replaces the previous
    # hard-coded `16000 * 12` (at this point sr == FEAT_CONFIG.sampling_rate).
    duration_samples = len(audio)
    target_len = min(int(np.ceil(duration_samples / sr * 1.15) * sr), sr * 12)
    if duration_samples > target_len:
        # BUGFIX: clips longer than the cap previously crashed with a numpy
        # broadcast error (`padded[:duration_samples] = audio` with
        # duration_samples > target_len). Truncate to the budget instead.
        audio = audio[:target_len]
        duration_samples = target_len
    padded = np.zeros(target_len, dtype=np.float32)
    padded[:duration_samples] = audio
    input_audio = padded[np.newaxis, :]
    input_lengths = jnp.array([duration_samples], dtype=jnp.int32)
    start = time.time()
    logits, output_lengths = jit_inference(MODEL, input_audio, input_lengths)
    logits.block_until_ready()
    seq_len = int(output_lengths[0])
    predicted_ids = jnp.argmax(logits[0, :seq_len], axis=-1)
    # === IMPORTANT DEBUG ===
    blank_count = jnp.sum(predicted_ids == TOKENIZER.blank_id)
    # Guard the ratio against seq_len == 0 (previously a ZeroDivisionError).
    blank_ratio = float(blank_count) / max(len(predicted_ids), 1)
    non_blank = len(predicted_ids) - int(blank_count)
    print(
        f"[DEBUG] Frames after subsampling: {seq_len} | Non-blank tokens: {non_blank} ({blank_ratio:.1%} blank)"
    )
    text = ctc_decode(predicted_ids, TOKENIZER)
    return text or "<EMPTY>", f"{time.time() - start:.3f}s"
# Discover the example FLAC clips bundled under test_audio/ and build the
# lookup tables the UI uses (display name -> Path).
TEST_AUDIO_DIR = Path("test_audio")
# sorted() accepts any iterable, so the intermediate list() was redundant.
AUDIO_FILES = sorted(TEST_AUDIO_DIR.glob("*.flac"))
FILE_OPTIONS = [f.name for f in AUDIO_FILES]
FILE_MAP = {f.name: f for f in AUDIO_FILES}
def transcribe_by_filename(filename):
    """Transcribe one of the bundled example clips selected in the dropdown.

    Args:
        filename: display name chosen in the UI (a key of FILE_MAP), or a
            falsy value when nothing is selected.

    Returns:
        (text, elapsed) exactly as `transcribe` returns, or a help message
        with a zero time for invalid selections.
    """
    path = FILE_MAP.get(filename) if filename else None
    if path is None:
        return "Please select a valid file.", "0.000s"
    # librosa resamples to the model's rate here, so `transcribe` will skip
    # its own resampling branch.
    waveform, rate = librosa.load(path, sr=FEAT_CONFIG.sampling_rate)
    return transcribe((rate, waveform))
def update_preview(filename):
    """Return the filesystem path (as str) for the selected example clip.

    Returns None for a falsy or unknown filename so the audio preview
    component is cleared.
    """
    if filename and filename in FILE_MAP:
        return str(FILE_MAP[filename])
    return None
# Pre-compile the jitted forward pass for several padded-length buckets so
# the first user request does not pay XLA compilation latency. The bucket
# durations should cover the lengths `transcribe` can produce (up to 12 s).
print("Pre-compiling...")
for sec in [3, 4, 5, 6, 8, 12]:
    dummy_len = sec * 16000  # assumes a 16 kHz model rate — TODO confirm vs FEAT_CONFIG
    dummy_audio = np.zeros((1, dummy_len), dtype=np.float32)
    dummy_lengths = jnp.array([int(dummy_len * 0.8)], dtype=jnp.int32)  # 80% filled
    logits, _ = jit_inference(MODEL, dummy_audio, dummy_lengths)
    # Block so compilation actually finishes before the next bucket / launch.
    logits.block_until_ready()
    print(f" Compiled {sec}s bucket")
# Gradio Interface with Blocks for premium look.
# Layout: two side-by-side panels — file selection + preview on the left,
# transcription results on the right.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    title="TinyVoice: Georgian ASR",
) as iface:
    gr.Markdown(
        """
    # 🎙️ TinyVoice: Georgian ASR Explorer
    """
    )
    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("### 🛠️ Audio Selection")
            # Searchable dropdown over the files discovered in test_audio/.
            audio_dropdown = gr.Dropdown(
                choices=FILE_OPTIONS,
                label="Search Audio Files",
                info=f"Listing {len(FILE_OPTIONS)} files from test_audio/",
                value=FILE_OPTIONS[0] if FILE_OPTIONS else None,
                filterable=True,
            )
            # Read-only player seeded with the first file (if any exist).
            audio_preview = gr.Audio(
                value=str(AUDIO_FILES[0]) if AUDIO_FILES else None,
                label="Audio Preview",
                interactive=False,
            )
            transcribe_btn = gr.Button(
                "🚀 Transcribe Now", variant="primary", size="lg"
            )
        with gr.Column(variant="panel"):
            gr.Markdown("### 📝 Results")
            output_text = gr.Textbox(
                label="Transcription",
                lines=8,
                placeholder="Transcription will appear here...",
            )
            output_time = gr.Textbox(label="Inference Time")
    # Event Handlers
    # Selecting a file swaps the preview player's source.
    audio_dropdown.change(
        fn=update_preview, inputs=audio_dropdown, outputs=audio_preview
    )
    # The button runs inference and fills both result boxes.
    transcribe_btn.click(
        fn=transcribe_by_filename,
        inputs=audio_dropdown,
        outputs=[output_text, output_time],
    )
if __name__ == "__main__":
    # BUGFIX: `Blocks.launch()` takes no `theme` keyword — the theme is
    # already set on the gr.Blocks(...) constructor above. Passing it here
    # raised a TypeError at startup.
    iface.launch()