File size: 4,765 Bytes
14c0892 b02797a cd72718 4e887cc cd72718 c956577 14c0892 e09c87a e3a7bb9 14c0892 c956577 b02797a c956577 4fbb118 c956577 4fbb118 c956577 e09c87a 4fbb118 e09c87a 4fbb118 e09c87a 4fbb118 b02797a 14c0892 4e887cc cd72718 4e887cc cd72718 4e887cc cd72718 4e887cc cd72718 4e887cc cd72718 4e887cc cd72718 14c0892 4e887cc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | import gradio as gr
import os
import shutil
import tarfile
import tempfile
import traceback
import torch
import soundfile as sf
import torchaudio.functional as F
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from nemo.collections.asr.models import ASRModel
# Hugging Face repository holding the unpacked (not yet .nemo-archived) model files.
MODEL_NAME = "trysem/stt_ml_fastconformer_ctc_med_punct"

print("1. Downloading unzipped model files from Hugging Face...")
model_dir = snapshot_download(repo_id=MODEL_NAME)

print("2. Patching the configuration...")
config_path = os.path.join(model_dir, "model_config.yaml")
config = OmegaConf.load(config_path)
# Struct mode forbids removing keys; switch it off so the patches below can pop entries.
OmegaConf.set_struct(config, False)

# --- Patch 1: the PyTorch 2.0 SDPA encoder flags are unknown to NeMo 1.23. ---
if 'encoder' in config:
    for sdpa_key in ('use_pytorch_sdpa', 'use_pytorch_sdpa_backends'):
        config.encoder.pop(sdpa_key, None)

# --- Patches 2 & 3: decoding strategy and confidence settings. ---
if 'decoding' in config:
    decoding_cfg = config.decoding

    # Patch 2: replace the 'greedy_batch' strategy with plain 'greedy'.
    if decoding_cfg.get('strategy') == 'greedy_batch':
        print(" -> Downgrading decoding strategy from 'greedy_batch' to 'greedy'")
        decoding_cfg.strategy = 'greedy'

    # Patch 3: NeMo 1.23 does not recognize confidence_cfg.tdt_include_duration.
    confidence_cfg = decoding_cfg.confidence_cfg if 'confidence_cfg' in decoding_cfg else None
    if confidence_cfg is not None and 'tdt_include_duration' in confidence_cfg:
        print(" -> Removing 'tdt_include_duration' from confidence_cfg")
        confidence_cfg.pop('tdt_include_duration', None)
print("3. Packaging files into a standard .nemo archive...")
patched_dir = "patched_nemo_env"
# Start from an empty working directory: with exist_ok=True alone, files left
# over from an earlier run would also be packed into the archive below.
shutil.rmtree(patched_dir, ignore_errors=True)
os.makedirs(patched_dir, exist_ok=True)

# Copy the top-level model files from the downloaded snapshot into the working
# directory (subdirectories are skipped; shutil.copy2 follows symlinks by default,
# so real file contents are copied).
for item in sorted(os.listdir(model_dir)):
    src_path = os.path.join(model_dir, item)
    if os.path.isfile(src_path):
        shutil.copy2(src_path, os.path.join(patched_dir, item))

# Overwrite the copied config with the patched version.
OmegaConf.save(config, os.path.join(patched_dir, "model_config.yaml"))

# Tar the working directory's contents (flat layout, uncompressed "w" mode) into
# the .nemo file handed to restore_from. Sorting keeps the archive order stable.
nemo_filepath = "patched_model.nemo"
with tarfile.open(nemo_filepath, "w") as tar:
    for item in sorted(os.listdir(patched_dir)):
        tar.add(os.path.join(patched_dir, item), arcname=item)

print("4. Restoring model from patched .nemo file (this may take a moment)...")
model = ASRModel.restore_from(restore_path=nemo_filepath)
# Put the model in inference mode before serving requests.
model.eval()
print("Model loaded successfully!")
def transcribe(audio_path):
    """Transcribe one audio file with the module-level NeMo ``model``.

    The audio is loaded with soundfile, down-mixed to mono, resampled to
    16 kHz, written to a temporary WAV file and passed to
    ``model.transcribe``. The temporary file is removed afterwards.

    Args:
        audio_path: Path to the input audio file (the Gradio ``Audio`` input
            is configured with ``type="filepath"``). ``None`` or ``""`` means
            no audio was supplied.

    Returns:
        The transcription, or a user-facing message string when no audio was
        supplied or any step failed. Exceptions are not propagated: they are
        turned into an "Error during transcription: ..." message for the UI,
        and the traceback is printed to stderr.
    """
    if not audio_path:
        return "Please upload or record audio."

    target_rate = 16000
    processed_path = None
    try:
        # 1. Load with soundfile to bypass torchcodec-based loading bugs.
        data, sample_rate = sf.read(audio_path)
        waveform = torch.from_numpy(data).float()

        # 2. soundfile yields [time] (mono) or [time, channels]; convert to
        #    torchaudio's [channels, time] layout.
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        else:
            waveform = waveform.transpose(0, 1)

        # 3. Down-mix multi-channel audio to a single channel.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # 4. Resample to 16 kHz if necessary.
        if sample_rate != target_rate:
            waveform = F.resample(waveform, sample_rate, target_rate)

        # 5. Write to a fresh temporary file rather than next to the upload, so
        #    it can be deleted in `finally` instead of accumulating per request.
        fd, processed_path = tempfile.mkstemp(suffix="_mono_16k.wav")
        os.close(fd)
        # soundfile expects mono data as a flat 1-D array: [time].
        sf.write(processed_path, waveform.squeeze(0).numpy(), target_rate)

        # 6. Run the NeMo model on the single processed file.
        transcription = model.transcribe(paths2audio_files=[processed_path])[0]
        if isinstance(transcription, list):
            transcription = transcription[0]
        # Defensive: if a hypothesis-like object (with a `.text` attribute) comes
        # back instead of a plain string, show its text rather than its repr.
        return getattr(transcription, "text", transcription)
    except Exception as e:
        # UI boundary: report the failure in the textbox, but keep the
        # traceback in the logs for debugging.
        traceback.print_exc()
        return f"Error during transcription: {str(e)}"
    finally:
        # Best-effort cleanup of the temporary WAV file.
        if processed_path is not None:
            try:
                os.remove(processed_path)
            except OSError:
                pass
# Define the Gradio Interface
with gr.Blocks(title="Malayalam FastConformer ASR") as demo:
gr.Markdown("# 🎙️ Malayalam FastConformer Speech-to-Text")
gr.Markdown("Upload an audio file or record from your microphone to generate a transcription.")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(type="filepath", label="Input Audio")
transcribe_btn = gr.Button("Transcribe", variant="primary")
with gr.Column():
text_output = gr.Textbox(label="Transcription", lines=5)
transcribe_btn.click(
fn=transcribe, |