import gradio as gr import os import shutil import tarfile import torch import soundfile as sf import torchaudio.functional as F from huggingface_hub import snapshot_download from omegaconf import OmegaConf from nemo.collections.asr.models import ASRModel # Pointing to your cloned model repository MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct" print("1. Downloading unzipped model files from Hugging Face...") model_dir = snapshot_download(repo_id=MODEL_NAME) print("2. Patching the configuration...") config_path = os.path.join(model_dir, "model_config.yaml") config = OmegaConf.load(config_path) # Allow modifications to the config object OmegaConf.set_struct(config, False) # Patch 1: Remove PyTorch 2.0 SDPA keys for NeMo 1.23 compatibility if 'encoder' in config: config.encoder.pop('use_pytorch_sdpa', None) config.encoder.pop('use_pytorch_sdpa_backends', None) # Patch 2 & 3: Fix decoding strategy and confidence config if 'decoding' in config: # Patch 2: Downgrade 'greedy_batch' strategy to 'greedy' if config.decoding.get('strategy') == 'greedy_batch': print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'") config.decoding.strategy = 'greedy' # Patch 3: Remove 'tdt_include_duration' which NeMo 1.23 doesn't recognize if 'confidence_cfg' in config.decoding and config.decoding.confidence_cfg is not None: if 'tdt_include_duration' in config.decoding.confidence_cfg: print(" -> Removing 'tdt_include_duration' from confidence_cfg") config.decoding.confidence_cfg.pop('tdt_include_duration', None) print("3. Packaging files into a standard .nemo archive...") patched_dir = "patched_nemo_env" os.makedirs(patched_dir, exist_ok=True) # Copy all the raw model files from the HF cache to our working directory for item in os.listdir(model_dir): s = os.path.join(model_dir, item) d = os.path.join(patched_dir, item) if os.path.isfile(s): shutil.copy2(s, d) # Overwrite the copied config with our cleaned version OmegaConf.save(config, os.path.join(patched_dir, "model_config.yaml")) # Tar it up into a standard .nemo file format that NeMo expects nemo_filepath = "patched_model.nemo" with tarfile.open(nemo_filepath, "w") as tar: for item in os.listdir(patched_dir): tar.add(os.path.join(patched_dir, item), arcname=item) print("4. Restoring model from patched .nemo file (this may take a moment)...") model = ASRModel.restore_from(restore_path=nemo_filepath) model.eval() print("Model loaded successfully!") def transcribe(audio_path): if not audio_path: return "Please upload or record audio." try: # 1. Load file using soundfile to completely bypass torchcodec bugs data, sample_rate = sf.read(audio_path) waveform = torch.from_numpy(data).float() # 2. Reshape soundfile format [time, channels] to torchaudio format [channels, time] if waveform.ndim == 1: waveform = waveform.unsqueeze(0) # Mono: [time] -> [1, time] else: waveform = waveform.transpose(0, 1) # Stereo: [time, channels] -> [channels, time] # 3. Convert to Mono if stereo if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) # 4. Resample to 16000 Hz if necessary if sample_rate != 16000: waveform = F.resample(waveform, sample_rate, 16000) sample_rate = 16000 # 5. Write the file back out using soundfile processed_path = audio_path + "_mono_16k.wav" # soundfile expects mono arrays to be flat 1D: [time] flat_numpy_waveform = waveform.squeeze(0).numpy() sf.write(processed_path, flat_numpy_waveform, 16000) # 6. Pass to NeMo model transcription = model.transcribe(paths2audio_files=[processed_path])[0] if isinstance(transcription, list): return transcription[0] return transcription except Exception as e: return f"Error during transcription: {str(e)}" # Define the Gradio Interface with gr.Blocks(title="Malayalam FastConformer ASR") as demo: gr.Markdown("# 🎙️ Malayalam FastConformer Speech-to-Text") gr.Markdown("Upload an audio file or record from your microphone to generate a transcription.") with gr.Row(): with gr.Column(): audio_input = gr.Audio(type="filepath", label="Input Audio") transcribe_btn = gr.Button("Transcribe", variant="primary") with gr.Column(): text_output = gr.Textbox(label="Transcription", lines=5) transcribe_btn.click( fn=transcribe,