KASR / app.py
trysem's picture
Update app.py
4e887cc verified
Raw
History Blame Contribute Delete
4.77 kB
import gradio as gr
import os
import shutil
import tarfile
import torch
import soundfile as sf
import torchaudio.functional as F
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from nemo.collections.asr.models import ASRModel
# Model preparation (runs once at import time).
#
# The Hugging Face repository holds the *unpacked* contents of a NeMo
# checkpoint. The steps below download those files, strip config keys that
# the installed NeMo release (1.23, per the patch notes) does not accept,
# re-pack everything into a .nemo tar archive and restore the model from it.

# Hugging Face repository that hosts the unpacked model files.
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"

print("1. Downloading unzipped model files from Hugging Face...")
model_dir = snapshot_download(repo_id=MODEL_NAME)

print("2. Patching the configuration...")
config_path = os.path.join(model_dir, "model_config.yaml")
config = OmegaConf.load(config_path)

# Struct mode must be disabled so that keys can be removed/changed below.
OmegaConf.set_struct(config, False)

# Patch 1: Remove PyTorch 2.0 SDPA keys for NeMo 1.23 compatibility
if 'encoder' in config:
    for sdpa_key in ('use_pytorch_sdpa', 'use_pytorch_sdpa_backends'):
        config.encoder.pop(sdpa_key, None)

# Patch 2 & 3: Fix decoding strategy and confidence config
if 'decoding' in config:
    # Patch 2: Downgrade 'greedy_batch' strategy to 'greedy'
    if config.decoding.get('strategy') == 'greedy_batch':
        print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
        config.decoding.strategy = 'greedy'

    # Patch 3: Remove 'tdt_include_duration' which NeMo 1.23 doesn't recognize
    if 'confidence_cfg' in config.decoding and config.decoding.confidence_cfg is not None:
        if 'tdt_include_duration' in config.decoding.confidence_cfg:
            print(" -> Removing 'tdt_include_duration' from confidence_cfg")
            config.decoding.confidence_cfg.pop('tdt_include_duration', None)

print("3. Packaging files into a standard .nemo archive...")
patched_dir = "patched_nemo_env"
# Start from an empty staging directory: files left over from a previous run
# (e.g. after the upstream repository changed) must not end up in the archive.
shutil.rmtree(patched_dir, ignore_errors=True)
os.makedirs(patched_dir, exist_ok=True)

# Copy all the raw model files from the HF cache to our working directory.
# Only regular files at the top level are copied; sub-directories are skipped.
for item in os.listdir(model_dir):
    src = os.path.join(model_dir, item)
    dst = os.path.join(patched_dir, item)
    if os.path.isfile(src):
        shutil.copy2(src, dst)

# Overwrite the copied config with our cleaned version
OmegaConf.save(config, os.path.join(patched_dir, "model_config.yaml"))

# Tar it up (uncompressed) with every file at the archive root, which is the
# layout this script relies on for ASRModel.restore_from.
nemo_filepath = "patched_model.nemo"
with tarfile.open(nemo_filepath, "w") as tar:
    for item in os.listdir(patched_dir):
        tar.add(os.path.join(patched_dir, item), arcname=item)

print("4. Restoring model from patched .nemo file (this may take a moment)...")
model = ASRModel.restore_from(restore_path=nemo_filepath)
model.eval()  # inference only
print("Model loaded successfully!")
def transcribe(audio_path):
    """Transcribe an audio file with the module-level NeMo ASR ``model``.

    The audio is read with soundfile, down-mixed to mono, resampled to
    16 kHz when its sample rate differs, written to a temporary WAV file and
    handed to ``model.transcribe``. The temporary file is always removed
    afterwards.

    Args:
        audio_path: Path of the audio file to transcribe. A falsy value
            (``None`` or an empty string) means no audio was provided.

    Returns:
        The transcription text. When no audio was provided or processing
        fails, a human-readable message is returned instead of raising, so
        the caller can display it directly.
    """
    if not audio_path:
        return "Please upload or record audio."

    # Local import: only needed here, for the scratch WAV file.
    import tempfile

    target_rate = 16000
    processed_path = None
    try:
        # 1. Load file using soundfile to completely bypass torchcodec bugs
        data, sample_rate = sf.read(audio_path)
        waveform = torch.from_numpy(data).float()
        if waveform.numel() == 0:
            return "Error during transcription: audio file contains no samples."

        # 2. Reshape soundfile format [time, channels] to [channels, time]
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)  # Mono: [time] -> [1, time]
        else:
            waveform = waveform.transpose(0, 1)  # [time, channels] -> [channels, time]

        # 3. Convert to mono by averaging the channels
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # 4. Resample to 16000 Hz if necessary
        if sample_rate != target_rate:
            waveform = F.resample(waveform, sample_rate, target_rate)

        # 5. Write the processed audio to a unique scratch file. Using the
        #    system temp directory avoids depending on the input file's
        #    directory being writable and avoids name collisions.
        fd, processed_path = tempfile.mkstemp(suffix="_mono_16k.wav")
        os.close(fd)  # soundfile reopens the path itself
        # soundfile expects mono arrays to be flat 1D: [time]
        sf.write(processed_path, waveform.squeeze(0).numpy(), target_rate)

        # 6. Pass to NeMo model
        transcription = model.transcribe(paths2audio_files=[processed_path])[0]
        if isinstance(transcription, list):
            return transcription[0]
        return transcription
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the app.
        return f"Error during transcription: {str(e)}"
    finally:
        # Always remove the scratch file so repeated requests do not fill the disk.
        if processed_path is not None:
            try:
                os.remove(processed_path)
            except OSError:
                pass  # best-effort cleanup
# Define the Gradio Interface
with gr.Blocks(title="Malayalam FastConformer ASR") as demo:
gr.Markdown("# 🎙️ Malayalam FastConformer Speech-to-Text")
gr.Markdown("Upload an audio file or record from your microphone to generate a transcription.")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(type="filepath", label="Input Audio")
transcribe_btn = gr.Button("Transcribe", variant="primary")
with gr.Column():
text_output = gr.Textbox(label="Transcription", lines=5)
transcribe_btn.click(
fn=transcribe,