Spaces:
Running
Running
| # Imports | |
| from pathlib import Path | |
| import tempfile | |
| import os | |
| import gradio as gr | |
| import librosa | |
| import tgt.core | |
| import tgt.io3 | |
| import soundfile as sf | |
| from transformers import pipeline | |
# Constants
# Scratch directory (created once at import) holding TextGrid files offered for download.
TEXTGRID_DIR = tempfile.mkdtemp()
DEFAULT_MODEL = "ginic/data_seed_bs64_4_wav2vec2-large-xlsr-53-buckeye-ipa"
TEXTGRID_DOWNLOAD_TEXT = "Download TextGrid file"
TEXTGRID_NAME_INPUT_LABEL = "TextGrid file name"

# Selection of models offered in the dropdown.
# The ginic checkpoints follow "ginic/<experiment>_<seed>_wav2vec2-large-xlsr-53-buckeye-ipa",
# with seeds numbered from 1; each (experiment, number_of_seeds) pair below expands in order.
VALID_MODELS = [
    "ctaguchi/wav2vec2-large-xlsr-japlmthufielta-ipa1000-ns",
    "ctaguchi/wav2vec2-large-xlsr-japlmthufielta-ipa-plus-2000",
    *(
        f"ginic/{experiment}_{seed}_wav2vec2-large-xlsr-53-buckeye-ipa"
        for experiment, seed_count in (
            ("data_seed_bs64", 4),
            ("gender_split_30_female", 5),
            ("gender_split_70_female", 5),
            ("vary_individuals_old_only", 3),
            ("vary_individuals_young_only", 3),
        )
        for seed in range(1, seed_count + 1)
    ),
]
def load_model_and_predict(
    model_name: str,
    audio_in: str,
    model_state: dict,
):
    """Transcribe an audio file with the selected ASR model.

    The pipeline cached in ``model_state`` is reused when its ``"model_name"``
    equals ``model_name``; otherwise a new ``automatic-speech-recognition``
    pipeline is built and a fresh state dict replaces the old one.

    Args:
        model_name: Hugging Face model id chosen in the dropdown.
        audio_in: Path to the audio file, or None when nothing was uploaded.
        model_state: Dict with keys ``"loaded_model"`` and ``"model_name"``.

    Returns:
        ``(prediction, model_state, filename_textbox)``:
        - the transcription text (``""`` when there is no audio);
        - the possibly refreshed model state;
        - a TextGrid file-name textbox. When a prediction was made, the box is
          editable and pre-filled with the audio's base name using a
          ``.TextGrid`` suffix; otherwise it is locked.
    """
    if audio_in is None:
        # Nothing to transcribe: keep the cached model and lock the name box.
        locked_name_box = gr.Textbox(label=TEXTGRID_NAME_INPUT_LABEL, interactive=False)
        return "", model_state, locked_name_box

    state = model_state
    if state["model_name"] != model_name:
        # Swap in the newly selected model; the old pipeline is dropped.
        asr = pipeline(task="automatic-speech-recognition", model=model_name)
        state = {"loaded_model": asr, "model_name": model_name}

    transcription = state["loaded_model"](audio_in)["text"]
    suggested_filename = Path(audio_in).with_suffix(".TextGrid").name
    editable_name_box = gr.Textbox(
        label=TEXTGRID_NAME_INPUT_LABEL,
        interactive=True,
        value=suggested_filename,
    )
    return transcription, state, editable_name_box
def get_textgrid_contents(audio_in, textgrid_tier_name, transcription_prediction):
    """Render a single-tier, long-format TextGrid for a whole-file transcription.

    The tier is named ``textgrid_tier_name``. It holds one interval, running
    from 0 to the audio's duration (measured with librosa), whose text is the
    prediction.

    Returns:
        The TextGrid text, or ``""`` when the audio path or the prediction is None.
    """
    if audio_in is None or transcription_prediction is None:
        return ""

    total_seconds = librosa.get_duration(path=audio_in)
    tier = tgt.core.IntervalTier(
        start_time=0, end_time=total_seconds, name=textgrid_tier_name
    )
    tier.add_annotation(tgt.core.Interval(0, total_seconds, transcription_prediction))

    grid = tgt.core.TextGrid()
    grid.add_tier(tier)
    return tgt.io3.export_to_long_textgrid(grid)
def write_textgrid(textgrid_contents, textgrid_filename, output_dir=None):
    """Write TextGrid contents to a named file in the temporary directory.

    Only the final path component of ``textgrid_filename`` is used, so a
    user-supplied name cannot place the file outside the output directory.
    A missing or empty name, or one with no usable final component (for
    example ``".."``), falls back to ``"transcription.TextGrid"``. Without
    this fallback, the target path would be the directory itself and the
    write would fail.

    Args:
        textgrid_contents: TextGrid text to write. IPA transcriptions are
            non-ASCII, so the file is always written as UTF-8 rather than in
            the platform's locale encoding.
        textgrid_filename: Requested download file name (may be None or "").
        output_dir: Directory to write into; defaults to ``TEXTGRID_DIR``.

    Returns:
        The ``pathlib.Path`` of the written file, for use as a download value.
    """
    base_dir = Path(TEXTGRID_DIR if output_dir is None else output_dir)
    filename = Path(textgrid_filename).name if textgrid_filename else ""
    if filename in ("", ".", ".."):
        filename = "transcription.TextGrid"
    textgrid_path = base_dir / filename
    textgrid_path.write_text(textgrid_contents, encoding="utf-8")
    return textgrid_path
def get_interactive_download_button(textgrid_contents, textgrid_filename):
    """Return a TextGrid download button, enabled only when there is content.

    When ``textgrid_contents`` is non-empty, it is written to disk with
    ``write_textgrid`` and the button serves that file. When the contents are
    empty, no file is written and a disabled button without a file is
    returned. The empty case happens after a Reset, which clears the contents
    textbox and re-fires its ``.change`` event.
    """
    if not textgrid_contents:
        # Avoid writing an empty TextGrid (and, with an empty file name,
        # writing to the temp directory itself) or re-enabling the button.
        return gr.DownloadButton(
            label=TEXTGRID_DOWNLOAD_TEXT,
            variant="primary",
            interactive=False,
            value=None,
        )
    return gr.DownloadButton(
        label=TEXTGRID_DOWNLOAD_TEXT,
        variant="primary",
        interactive=True,
        value=write_textgrid(textgrid_contents, textgrid_filename),
    )
def _transcribe_segment(asr, audio_path, start, end):
    """Run ``asr`` on the ``[start, start + (end - start)]`` slice of ``audio_path`` and return its text."""
    # sr=None keeps the file's native sampling rate for the slice.
    samples, sample_rate = librosa.load(
        audio_path, sr=None, offset=start, duration=end - start
    )
    # Create the temp file, then close it before reopening it by name:
    # reopening a still-open NamedTemporaryFile is not possible on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        temp_name = temp_audio.name
    try:
        sf.write(temp_name, samples, sample_rate)
        return asr(temp_name)["text"]
    finally:
        # Always clean up, even when writing or inference fails.
        os.remove(temp_name)


def transcribe_intervals(audio_in, textgrid_path, source_tier, target_tier, model_state):
    """Transcribe each non-empty interval of a TextGrid tier into a new tier.

    Every interval of ``source_tier`` with non-blank text is cut from
    ``audio_in`` and transcribed with ``model_state["loaded_model"]``. The
    result goes into a new interval tier named ``target_tier``, which spans
    the same time range as the source tier. If one interval fails, its text
    becomes ``"[Error]: <message>"`` and the remaining intervals are still
    processed.

    Args:
        audio_in: Path to the audio file, or None.
        textgrid_path: Uploaded TextGrid, either a path string or an object
            exposing the path as ``.name``; or None.
        source_tier: Name of the existing tier whose intervals are transcribed.
        target_tier: Name for the new tier.
        model_state: Dict holding the ASR pipeline under ``"loaded_model"``.

    Returns:
        The updated TextGrid in long format, or an explanatory message when
        either input file is missing.
    """
    if audio_in is None or textgrid_path is None:
        return "Missing audio or TextGrid input file."

    tg = tgt.io.read_textgrid(getattr(textgrid_path, "name", textgrid_path))
    tier = tg.get_tier_by_name(source_tier)
    ipa_tier = tgt.core.IntervalTier(
        start_time=tier.start_time, end_time=tier.end_time, name=target_tier
    )
    for interval in tier.intervals:
        if not interval.text.strip():  # Skip empty text intervals
            continue
        start, end = interval.start_time, interval.end_time
        try:
            prediction = _transcribe_segment(
                model_state["loaded_model"], audio_in, start, end
            )
        except Exception as e:
            # Best effort: record the failure in the tier instead of aborting.
            prediction = f"[Error]: {e}"
        ipa_tier.add_annotation(tgt.core.Interval(start, end, prediction))

    tg.add_tier(ipa_tier)
    return tgt.io3.export_to_long_textgrid(tg)
def extract_tier_names(textgrid_file):
    """Fill the source-tier dropdown with the tier names of an uploaded TextGrid.

    Args:
        textgrid_file: Uploaded file, either a path string or an object
            exposing the path as ``.name``; or None when the upload is cleared.

    Returns:
        A ``gr.update`` whose choices are the tier names, with the first tier
        preselected. The choices are empty (and no value is set) when there is
        no file or it cannot be read.
    """
    if textgrid_file is None:
        # Upload cleared (e.g. by Reset): empty the dropdown.
        return gr.update(choices=[], value=None)
    try:
        tg = tgt.io.read_textgrid(getattr(textgrid_file, "name", textgrid_file))
    except Exception:
        # Best effort: an unreadable or malformed file just yields no choices.
        return gr.update(choices=[], value=None)
    tier_names = [tier.name for tier in tg.tiers]
    return gr.update(choices=tier_names, value=tier_names[0] if tier_names else None)
def launch_demo():
    """Build the Gradio UI for full-audio and interval IPA transcription and launch it.

    The default ASR pipeline is loaded once at startup and kept in a
    ``gr.State`` dict ``{"loaded_model": <pipeline>, "model_name": <str>}``.
    Both transcription modes replace it when a different model is selected.
    """
    initial_model = {
        "loaded_model": pipeline(
            task="automatic-speech-recognition", model=DEFAULT_MODEL
        ),
        "model_name": DEFAULT_MODEL,
    }

    # Helper function - enables the interval transcribe button
    def enable_interval_transcribe_btn(audio, textgrid):
        return gr.update(interactive=(audio is not None and textgrid is not None))

    def download_button_for(textgrid_contents, textgrid_filename):
        """Return a download-button update for the given TextGrid text.

        Reset handlers clear the contents, which re-fires ``.change``. In that
        case, keep the button disabled instead of writing an empty file.
        """
        if not textgrid_contents:
            return gr.update(value=None, interactive=False)
        return get_interactive_download_button(
            textgrid_contents, textgrid_filename or "transcription.TextGrid"
        )

    def transcribe_intervals_with_selected_model(
        selected_model, audio, textgrid, source_tier_name, target_tier_name, state
    ):
        """Run interval transcription with the model chosen in the dropdown.

        This loads and caches the selected pipeline the same way
        ``load_model_and_predict`` does, so the interval mode does not silently
        reuse a previously loaded model. Returns ``(result_text, state)``.
        """
        if (
            audio is not None
            and textgrid is not None
            and state["model_name"] != selected_model
        ):
            state = {
                "loaded_model": pipeline(
                    task="automatic-speech-recognition", model=selected_model
                ),
                "model_name": selected_model,
            }
        result = transcribe_intervals(
            audio, textgrid, source_tier_name, target_tier_name, state
        )
        return result, state

    with gr.Blocks() as demo:
        gr.Markdown("""# Automatic International Phonetic Alphabet Transcription
This demo allows you to experiment with producing phonetic transcriptions of uploaded or recorded audio using a selected automatic speech recognition (ASR) model.""")

        # Dropdown for model selection
        model_name = gr.Dropdown(
            VALID_MODELS,
            value=DEFAULT_MODEL,
            label="IPA transcription ASR model",
            info="Select the model to use for prediction.",
        )

        # Dropdown for transcription type selection
        transcription_type = gr.Dropdown(
            choices=["Full Audio", "Interval"],
            label="Transcription Type",
            value=None,
            interactive=True,
        )

        model_state = gr.State(value=initial_model)

        # Full audio transcription section
        with gr.Column(visible=False) as full_audio_section:
            full_audio = gr.Audio(type="filepath", show_download_button=True, label="Upload Audio File")
            full_transcribe_btn = gr.Button("Transcribe Full Audio", interactive=False, variant="primary")
            full_prediction = gr.Textbox(label="IPA Transcription", show_copy_button=True)
            full_textgrid_tier = gr.Textbox(label="TextGrid Tier Name", value="transcription", interactive=True)
            full_textgrid_filename = gr.Textbox(label=TEXTGRID_NAME_INPUT_LABEL, interactive=False)
            full_textgrid_contents = gr.Textbox(label="TextGrid Contents", show_copy_button=True)
            full_download_btn = gr.DownloadButton(label=TEXTGRID_DOWNLOAD_TEXT, interactive=False, variant="primary")
            full_reset_btn = gr.Button("Reset", variant="secondary")

        # Interval transcription section
        with gr.Column(visible=False) as interval_section:
            interval_audio = gr.Audio(type="filepath", show_download_button=True, label="Upload Audio File")
            interval_textgrid_file = gr.File(file_types=[".TextGrid"], label="Upload TextGrid File")
            tier_names = gr.Dropdown(label="Source Tier (existing)", choices=[], interactive=True)
            target_tier = gr.Textbox(label="Target Tier (new)", value="IPATier", placeholder="e.g. IPATier")
            interval_transcribe_btn = gr.Button("Transcribe Intervals", interactive=False, variant="primary")
            interval_result = gr.Textbox(label="IPA Interval Transcription", show_copy_button=True, interactive=False)
            interval_download_btn = gr.DownloadButton(label=TEXTGRID_DOWNLOAD_TEXT, interactive=False, variant="primary")
            interval_reset_btn = gr.Button("Reset", variant="secondary")

        # Section visibility toggle
        transcription_type.change(
            fn=lambda t: (
                gr.update(visible=t == "Full Audio"),
                gr.update(visible=t == "Interval"),
            ),
            inputs=transcription_type,
            outputs=[full_audio_section, interval_section],
        )

        # Enable full transcribe button after audio uploaded
        full_audio.change(
            fn=lambda audio: gr.update(interactive=audio is not None),
            inputs=full_audio,
            outputs=full_transcribe_btn,
        )

        # Full transcription logic
        full_transcribe_btn.click(
            fn=load_model_and_predict,
            inputs=[model_name, full_audio, model_state],
            outputs=[full_prediction, model_state, full_textgrid_filename],
        )
        full_prediction.change(
            fn=get_textgrid_contents,
            inputs=[full_audio, full_textgrid_tier, full_prediction],
            outputs=[full_textgrid_contents],
        )
        full_textgrid_contents.change(
            fn=download_button_for,
            inputs=[full_textgrid_contents, full_textgrid_filename],
            outputs=[full_download_btn],
        )
        full_reset_btn.click(
            fn=lambda: (None, "", "", "", gr.update(interactive=False)),
            outputs=[full_audio, full_prediction, full_textgrid_filename, full_textgrid_contents, full_download_btn],
        )

        # Enable interval transcribe button only when both files are uploaded
        interval_audio.change(
            fn=enable_interval_transcribe_btn,
            inputs=[interval_audio, interval_textgrid_file],
            outputs=[interval_transcribe_btn],
        )
        interval_textgrid_file.change(
            fn=enable_interval_transcribe_btn,
            inputs=[interval_audio, interval_textgrid_file],
            outputs=[interval_transcribe_btn],
        )

        # Interval logic
        interval_textgrid_file.change(
            fn=extract_tier_names,
            inputs=[interval_textgrid_file],
            outputs=[tier_names],
        )
        interval_transcribe_btn.click(
            fn=transcribe_intervals_with_selected_model,
            inputs=[model_name, interval_audio, interval_textgrid_file, tier_names, target_tier, model_state],
            outputs=[interval_result, model_state],
        )
        interval_result.change(
            fn=lambda tg_text: download_button_for(tg_text, "interval_output.TextGrid"),
            inputs=[interval_result],
            outputs=[interval_download_btn],
        )
        interval_reset_btn.click(
            fn=lambda: (None, None, gr.update(choices=[]), "IPATier", "", gr.update(interactive=False)),
            outputs=[interval_audio, interval_textgrid_file, tier_names, target_tier, interval_result, interval_download_btn],
        )

    demo.launch(max_file_size="100mb")
def _main():
    """Start the demo; called only when this file is executed as a script."""
    launch_demo()


if __name__ == "__main__":
    _main()