# Spaces: Running — Hugging Face Space status banner captured by the page
# scrape; kept here as a comment so the file remains valid Python.
| import gradio as gr | |
| import torch | |
| from transformers import AutoProcessor, VoxtralForConditionalGeneration | |
| import spaces | |
| #### Functions | |
def process_transcript(language: str, audio_path: str) -> str:
    """Process the audio file to return its transcription.

    Args:
        language: Display name of the audio language (a key of ``dict_languages``).
        audio_path: Path to the audio file, or ``None`` when the user gave no input.

    Returns:
        The transcribed text of the audio, or an instructive message when no
        audio was provided.
    """
    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."
    id_language = dict_languages[language]
    # transformers first shipped this helper under a misspelled name
    # ("apply_transcrition_request") and later added the corrected spelling;
    # resolve whichever one this installed version exposes.
    build_request = getattr(processor, "apply_transcription_request",
                            getattr(processor, "apply_transcrition_request", None))
    inputs = build_request(language=id_language, audio=audio_path, model_id=model_name)
    inputs = inputs.to(device, dtype=torch.bfloat16)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    # Decode only the newly generated tokens (everything after the prompt).
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:],
                                             skip_special_tokens=True)
    return decoded_outputs[0]
| ### | |
def process_translate(language: str, audio_path: str) -> str:
    """Translate the audio file's content into the requested language.

    Args:
        language: Display name of the target language.
        audio_path: Path to the audio file, or ``None`` when the user gave no input.

    Returns:
        The translated text, or an instructive message when no audio was
        provided (mirrors ``process_transcript``).
    """
    # Guard clause kept consistent with process_transcript: without it a
    # missing upload would crash inside the processor.
    if audio_path is None:
        return "Please provide some input audio: either upload an audio file or use the microphone."
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "type": "audio",
                    "path": audio_path,
                },
                {"type": "text", "text": "Translate this in "+language},
            ],
        }
    ]
    inputs = processor.apply_chat_template(conversation)
    inputs = inputs.to(device, dtype=torch.bfloat16)
    outputs = model.generate(**inputs, max_new_tokens=MAX_TOKENS)
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:],
                                             skip_special_tokens=True)
    # Bug fix: return the single decoded string, not the whole list — the
    # Textbox component would otherwise display the list's repr, and
    # process_transcript already returns decoded_outputs[0].
    return decoded_outputs[0]
def disable_buttons():
    """Grey out all three action buttons while a request is being processed."""
    # One fresh update per button, all locking interaction.
    return tuple(gr.update(interactive=False) for _ in range(3))
def enable_buttons():
    """Re-enable all three action buttons once a request has finished.

    Bug fix: the third update previously used ``interactive=False``, which
    left the "Ask audio file" button permanently disabled after the first
    transcription or translation completed.
    """
    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
###
### Initializations
MAX_TOKENS = 32000  # cap passed to model.generate for every request
# Prefer the GPU when one is visible; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"*** Device: {device}")
model_name = 'mistralai/Voxtral-Mini-3B-2507'
# Processor and model are loaded once at import time and shared by all handlers.
processor = AutoProcessor.from_pretrained(model_name)
model = VoxtralForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
# Supported languages: display name -> language id handed to the processor.
dict_languages = {
    "English": "en",
    "French": "fr",
    "German": "de",
    "Spanish": "es",
    "Italian": "it",
    "Portuguese": "pt",
    "Dutch": "nl",
    "Hindi": "hi",
}
#### Gradio interface
with gr.Blocks(title="Transcription") as audio:
    gr.Markdown("# Voxtral Mini Evaluation")
    gr.Markdown("#### Choose the language of the audio and set an audio file to process it.")
    gr.Markdown("##### *(Voxtral handles audios up to 30 minutes for transcription)*")
    with gr.Row():
        with gr.Column():
            sel_language = gr.Dropdown(
                choices=list(dict_languages.keys()),
                value="English",
                label="Select the language of the audio file:"
            )
        with gr.Column():
            sel_audio = gr.Audio(sources=["upload", "microphone"], type="filepath",
                                 label="Upload an audio file, record via microphone, or select a demo file:")
            example = [["mapo_tofu.mp3"]]
            gr.Examples(
                examples=example,
                inputs=sel_audio,
                outputs=None,
                fn=None,
                cache_examples=False,
                run_on_click=False
            )
    with gr.Row():
        with gr.Column():
            submit_transcript = gr.Button("Extract transcription", variant="primary")
            text_transcript = gr.Textbox(label="Generated transcription", lines=10)
        with gr.Column():
            sel_translate_language = gr.Dropdown(
                choices=list(dict_languages.keys()),
                value="English",
                label="Select the language for translation:"
            )
            submit_translate = gr.Button("Translate audio file", variant="primary")
            text_translate = gr.Textbox(label="Generated translation", lines=10)
        with gr.Column():
            submit_chat = gr.Button("Ask audio file", variant="primary")
            text_chat = gr.Textbox(label="Model answer", lines=10)

    ### Processing
    # Transcription: lock the buttons, run the model, then unlock.
    submit_transcript.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_transcript,
        inputs=[sel_language, sel_audio],
        outputs=text_transcript
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )
    # Translation — bug fix: this chain was a copy-paste of the one above
    # (submit_transcript.click / process_transcript / text_transcript), so the
    # "Translate audio file" button did nothing. It is now wired to the
    # translate button, handler, and output box.
    submit_translate.click(
        disable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
        trigger_mode="once",
    ).then(
        fn=process_translate,
        inputs=[sel_translate_language, sel_audio],
        outputs=text_translate
    ).then(
        enable_buttons,
        outputs=[submit_transcript, submit_translate, submit_chat],
    )
    # NOTE(review): submit_chat / text_chat have no handler wired anywhere in
    # this file — presumably a Q&A function is still to be added; left unwired
    # rather than inventing behavior.

### Launch the app
if __name__ == "__main__":
    audio.launch()