| import gradio as gr |
| import librosa |
| import os |
| import logging |
| from pathlib import Path |
| import torch |
| from transformers import Wav2Vec2ForCTC, AutoProcessor |
| import numpy as np |
| import spaces |
|
|
| |
# Module-wide logging: DEBUG level so each step of the upload/transcription
# flow shows up in the logs.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Sampling rate (Hz) that audio is loaded at and that is reported to the
# processor.
ASR_SAMPLING_RATE = 16_000

# Language table mapping ISO code -> display name, read from a text file with
# one "<iso><whitespace><name>" entry per line.
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv", encoding="utf-8") as f:
    for line_number, line in enumerate(f, start=1):
        # Split on the first run of whitespace so both space- and
        # tab-separated entries parse; the name itself may contain spaces.
        fields = line.split(None, 1)
        if not fields:
            # Blank line (e.g. a trailing newline at end of file).
            continue
        if len(fields) != 2:
            logger.warning(
                "Skipping malformed language entry on line %d: %r",
                line_number,
                line,
            )
            continue
        iso, name = fields
        ASR_LANGUAGES[iso] = name.strip()

# Checkpoint identifier passed to ``from_pretrained`` below.
MODEL_ID = "facebook/mms-1b-all"

# Loaded once at import time and shared by the functions in this module.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
|
|
def safe_process_file(file_obj):
    """Load an audio file as mono samples at ``ASR_SAMPLING_RATE``.

    Args:
        file_obj: Path to the audio file (``str`` or ``os.PathLike``), or an
            object that exposes the path through a ``name`` attribute (for
            example a temporary-file wrapper).

    Returns:
        A ``(audio_samples, sr, safe_name)`` tuple: the samples and sampling
        rate returned by ``librosa.load``, and a display name of the form
        ``audio_<stem>.wav`` derived from the input file name.

    Raises:
        TypeError: If ``file_obj`` is neither a path nor an object carrying a
            path in its ``name`` attribute.
        Exception: Any error raised while loading the audio is logged (with
            traceback) and re-raised unchanged.
    """
    try:
        logger.debug(f"Processing file: {file_obj}")

        # A Path also has a ``name`` attribute (its basename), so real paths
        # must be recognised first and used as-is.
        if isinstance(file_obj, (str, os.PathLike)):
            source = file_obj
        else:
            source = getattr(file_obj, "name", None)
            if not isinstance(source, (str, os.PathLike)):
                raise TypeError(
                    "expected a file path or an object with a 'name' "
                    f"attribute, not {type(file_obj).__name__}"
                )
        file_path = Path(source)

        logger.debug(f"Loading audio from file path: {file_path}")

        audio_samples, sr = librosa.load(
            str(file_path), sr=ASR_SAMPLING_RATE, mono=True
        )

        safe_name = f"audio_{file_path.stem}.wav"
        logger.debug(f"File processed successfully: {safe_name}")
        return audio_samples, sr, safe_name
    except Exception:
        # Fall back to the input itself so plain string paths are reported
        # instead of a meaningless placeholder.
        logger.exception(
            f"Error processing file {getattr(file_obj, 'name', file_obj)}"
        )
        raise
|
|
def transcribe_multiple_files(audio_files, lang, transcription):
    """Transcribe one or more audio files and return a combined text report.

    Args:
        audio_files: A single audio file (anything ``safe_process_file``
            accepts) or a ``list`` of such files.
        lang: Language selection, forwarded to ``transcribe_file``.
        transcription: User-provided transcription; it is logged and
            forwarded to ``transcribe_file``.

    Returns:
        A string with one section per input file, joined by newlines. A
        successful file yields ``"File: <name>\\nTranscription: <text>\\n"``;
        a failing file yields ``"Error processing file: <reason>\\n"`` and
        does not prevent the remaining files from being processed.
    """
    # Only a real list is treated as "several files"; any other value is
    # handled as one file, exactly as before.
    files = audio_files if isinstance(audio_files, list) else [audio_files]
    if not files:
        return "Error processing file: no audio files provided\n"

    transcriptions = []
    for audio_file in files:
        # Per-file error isolation: one bad upload must not discard the
        # results of the others.
        try:
            audio_samples, sr, safe_name = safe_process_file(audio_file)
            logger.debug(f"Transcribing file {audio_file}: {safe_name}")
            logger.debug(f"Language selected: {lang}")
            logger.debug(f"User-provided transcription: {transcription}")

            result = transcribe_file(model, audio_samples, lang, transcription)
            logger.debug(f"Transcription result: {result}")

            transcriptions.append(
                f"File: {safe_name}\nTranscription: {result}\n"
            )
        except Exception as e:
            logger.error(f"Error in transcription process: {str(e)}")
            transcriptions.append(f"Error processing file: {str(e)}\n")
    return "\n".join(transcriptions)
|
|
@spaces.GPU
def transcribe_file(model, audio_samples, lang, user_transcription):
    """Run the model on one audio clip and return the decoded text.

    Args:
        model: Model to run; it must provide ``load_adapter`` and return an
            object with a ``logits`` attribute when called.
        audio_samples: Audio samples handed to the module-level ``processor``
            together with ``ASR_SAMPLING_RATE``.
        lang: Language selection string; only its first whitespace-separated
            token is used as the language code.
        user_transcription: Accepted for interface compatibility; not used
            by this function.

    Returns:
        The text produced by ``processor.decode`` for the first item in the
        batch, using the highest-scoring token id at every position.
    """
    # e.g. "eng (English)" -> "eng"
    iso_code = lang.split(maxsplit=1)[0]

    # Point both tokenizer and model at the selected language.
    processor.tokenizer.set_target_lang(iso_code)
    model.load_adapter(iso_code)

    batch = processor(
        audio_samples,
        sampling_rate=ASR_SAMPLING_RATE,
        return_tensors="pt",
    )

    # Use the GPU when one is visible, otherwise stay on CPU.
    run_on = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(run_on)
    batch = batch.to(run_on)

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**batch).logits

    # Best token id per position, first (only) item of the batch.
    best_ids = logits.argmax(dim=-1)[0]
    return processor.decode(best_ids)
|
|
@spaces.GPU
def fine_tune_model(
    model,
    processor,
    user_transcription,
    audio_samples,
    lang_code,
    num_epochs=5,
    learning_rate=0.001,
):
    """Fine-tune ``model`` on a single (audio, transcription) pair.

    The CTC loss is computed by the model itself: the tokenized transcription
    is passed as ``labels`` and the returned ``loss`` is back-propagated.

    Args:
        model: CTC model that provides ``load_adapter`` and returns an object
            with a ``loss`` attribute when called with ``labels``
            (e.g. ``Wav2Vec2ForCTC``).
        processor: Processor used to featurize the audio; its ``tokenizer``
            encodes the transcription.
        user_transcription: Reference text for ``audio_samples``.
        audio_samples: Audio samples, expected at ``ASR_SAMPLING_RATE``.
        lang_code: Language code used to select the tokenizer vocabulary and
            the model adapter.
        num_epochs: Number of optimisation steps over the single example.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        The same ``model`` object, switched back to evaluation mode. If the
        transcription is empty the model is returned untouched.
    """
    # Nothing to learn from without a reference transcription.
    if not user_transcription or not user_transcription.strip():
        logger.warning("Skipping fine-tuning: empty reference transcription")
        return model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer vocabulary and model adapter must match the target language,
    # otherwise label ids would not line up with the output layer.
    # NOTE(review): transcribe_file calls model.load_adapter on every request,
    # which presumably overwrites weights tuned here - confirm intended flow.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)
    model.to(device)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    ).to(device)
    labels = processor.tokenizer(
        user_transcription, return_tensors="pt"
    ).input_ids.to(device)

    # NOTE(review): updating every parameter with Adam at the default rate is
    # aggressive for a single example; consider a smaller rate or fewer
    # trainable parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    try:
        for epoch in range(num_epochs):
            optimizer.zero_grad()
            loss = model(**inputs, labels=labels).loss
            loss.backward()
            optimizer.step()
            logger.debug(
                f"Fine-tuning epoch {epoch + 1}/{num_epochs}: "
                f"loss={loss.item():.4f}"
            )
    finally:
        # Always leave the shared model in inference mode, even if a
        # training step fails.
        model.eval()

    return model
|
|
# Example rows, each an [audio file path, language label] pair.
ASR_EXAMPLES = [["upload/english.mp3", "eng (English)"]]

# Note text (Markdown) about decoding accuracy and where to find
# language-model decoding instructions.
ASR_NOTE = (
    "\n"
    "The above demo doesn't use beam-search decoding using a language model.\n"
    "Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all)"
    " on how to run LM decoding for better accuracy.\n"
)