# Voxtral / app.py
# Author: MohamedRashad
# Initial implementation of Voxtral audio processing app with Gradio interface
# (commit f0d5b79)
from pathlib import Path
import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from transformers import AutoProcessor, VoxtralForConditionalGeneration
# Model paths and setup.
# Both checkpoints are fetched at import time so the Gradio handlers can use
# the pre-loaded models; the download is a no-op when the files already exist
# in the local_dir cache.
# NOTE(review): `resume_download` is deprecated (and ignored) in recent
# versions of huggingface_hub — confirm the pinned hub version before removing.
voxtral_mini_path = snapshot_download(
repo_id='mistralai/Voxtral-Mini-3B-2507',
revision='refs/pr/16',
local_dir=Path(__file__).parent / 'Voxtral-Mini-3B-2507',
resume_download=True,
)
print(f"Voxtral Mini model downloaded to: {voxtral_mini_path}")
voxtral_small_path = snapshot_download(
repo_id='mistralai/Voxtral-Small-24B-2507',
revision='refs/pr/9',
local_dir=Path(__file__).parent / 'Voxtral-Small-24B-2507',
resume_download=True,
)
print(f"Voxtral Small model downloaded to: {voxtral_small_path}")
# Prefer GPU when available; `device_map=device` below places each model there.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load model and processor for both checkpoints.
# bfloat16 halves memory versus float32; both models are kept resident so
# switching models in the UI does not trigger a reload.
voxtral_mini_processor = AutoProcessor.from_pretrained(voxtral_mini_path)
voxtral_mini_model = VoxtralForConditionalGeneration.from_pretrained(voxtral_mini_path, torch_dtype=torch.bfloat16, device_map=device)
voxtral_small_processor = AutoProcessor.from_pretrained(voxtral_small_path)
voxtral_small_model = VoxtralForConditionalGeneration.from_pretrained(voxtral_small_path, torch_dtype=torch.bfloat16, device_map=device)
@spaces.GPU()
def process_audio(audio_path, model_name, language="en", max_tokens=500):
    """Transcribe an uploaded audio file with the selected Voxtral model.

    Args:
        audio_path: Filesystem path to the audio file (from ``gr.Audio``);
            falsy when the user submitted without uploading anything.
        model_name: One of ``"Voxtral Mini (3B)"`` or ``"Voxtral Small (24B)"``.
        language: Language code forwarded to the transcription request.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded transcription string, or a human-readable error message
        when the input is missing or the model name is unrecognized.
    """
    if not audio_path:
        return "Please upload an audio file."

    # Pick the pre-loaded (model, processor, checkpoint path) triple.
    if model_name == "Voxtral Mini (3B)":
        model = voxtral_mini_model
        processor = voxtral_mini_processor
        repo_id = str(voxtral_mini_path)
    elif model_name == "Voxtral Small (24B)":
        model = voxtral_small_model
        processor = voxtral_small_processor
        repo_id = str(voxtral_small_path)
    else:
        return "Invalid model selected."

    inputs = processor.apply_transcription_request(language=language, audio=audio_path, model_id=repo_id)
    inputs = inputs.to(device, dtype=torch.bfloat16)

    # inference_mode disables autograd bookkeeping during generation,
    # cutting GPU memory use; generated token values are unchanged.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=max_tokens)

    # Slice off the prompt tokens so only the newly generated text is decoded.
    decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return decoded_outputs[0]
# Define Gradio interface.
# NOTE: the raw SOURCE had all body indentation stripped; the structure below
# restores the only valid nesting implied by the `with` context managers.
with gr.Blocks(title="Voxtral Demo") as demo:
    gr.Markdown("# Voxtral Audio Processing Demo")
    gr.Markdown("Upload an audio file and get a transcription/response from Voxtral.")

    with gr.Row():
        with gr.Column():
            # Input controls: audio file, model choice, language, length cap.
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            model_selector = gr.Dropdown(
                choices=["Voxtral Mini (3B)", "Voxtral Small (24B)"],
                value="Voxtral Mini (3B)",
                label="Select Model"
            )
            language = gr.Dropdown(
                choices=["en", "fr", "de", "es", "it", "pt", "nl", "ru", "zh", "ja", "ar"],
                value="en",
                label="Language"
            )
            max_tokens = gr.Slider(minimum=50, maximum=1000, value=500, step=50, label="Max Output Tokens")
            submit_btn = gr.Button("Process Audio")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Response", lines=10)

    # Wire the button to the inference function defined above.
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input, model_selector, language, max_tokens],
        outputs=output_text
    )

# Launch the app.
if __name__ == "__main__":
    # queue() serializes requests to the shared GPU; share=True publishes
    # a temporary public Gradio link.
    demo.queue().launch(share=True)