Voxtral / app.py
MohamedRashad's picture
Refactor model loading to use direct model paths instead of snapshot downloads
aa14541
raw
history blame
3.14 kB
import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, VoxtralForConditionalGeneration
# Pick the compute device once at import time: CUDA when torch reports it is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Load model and processor
# Both checkpoints are loaded eagerly here (processor first, then weights in
# bfloat16 placed on `device`), so they are ready before any request arrives.
voxtral_mini_processor = AutoProcessor.from_pretrained(
    "MohamedRashad/Voxtral-Mini-3B-2507-transformers"
)
voxtral_mini_model = VoxtralForConditionalGeneration.from_pretrained(
    "MohamedRashad/Voxtral-Mini-3B-2507-transformers",
    torch_dtype=torch.bfloat16,
    device_map=device,
)

voxtral_small_processor = AutoProcessor.from_pretrained(
    "MohamedRashad/Voxtral-Small-24B-2507-transformers"
)
voxtral_small_model = VoxtralForConditionalGeneration.from_pretrained(
    "MohamedRashad/Voxtral-Small-24B-2507-transformers",
    torch_dtype=torch.bfloat16,
    device_map=device,
)
@spaces.GPU()
def process_audio(audio_path, model_name, language="en", max_tokens=500):
    """Run a transcription request through the chosen Voxtral model.

    Args:
        audio_path: Path to the audio file to process. A falsy value
            (``None`` or an empty string) short-circuits with a prompt to
            upload a file.
        model_name: UI label of the model to use; either
            ``"Voxtral Mini (3B)"`` or ``"Voxtral Small (24B)"``.
        language: Language code forwarded to the processor's
            ``apply_transcription_request``.
        max_tokens: Forwarded to ``generate`` as ``max_new_tokens``.

    Returns:
        The first decoded string produced by the model, or a short
        human-readable message when the audio is missing or the model
        label is not recognised.
    """
    if not audio_path:
        return "Please upload an audio file."

    # One row per supported UI label: (label, model, processor, hub repo id).
    available = (
        (
            "Voxtral Mini (3B)",
            voxtral_mini_model,
            voxtral_mini_processor,
            "MohamedRashad/Voxtral-Mini-3B-2507-transformers",
        ),
        (
            "Voxtral Small (24B)",
            voxtral_small_model,
            voxtral_small_processor,
            "MohamedRashad/Voxtral-Small-24B-2507-transformers",
        ),
    )

    # Scan for the matching label; the `else` arm only runs when the loop
    # finishes without hitting `break`, i.e. no label matched.
    for label, model, processor, repo_id in available:
        if model_name == label:
            break
    else:
        return "Invalid model selected."

    request = processor.apply_transcription_request(
        language=language,
        audio=audio_path,
        model_id=repo_id,
    )
    request = request.to(device, dtype=torch.bfloat16)

    generated = model.generate(**request, max_new_tokens=max_tokens)

    # Drop the leading positions that correspond to the input ids so only the
    # newly generated part is decoded.
    prompt_length = request.input_ids.shape[1]
    texts = processor.batch_decode(
        generated[:, prompt_length:],
        skip_special_tokens=True,
    )
    return texts[0]
# Define Gradio interface
with gr.Blocks(title="Voxtral Demo") as demo:
    # Page heading and a one-line usage hint.
    gr.Markdown("# Voxtral Audio Processing Demo")
    gr.Markdown("Upload an audio file and get a transcription/response from Voxtral.")

    with gr.Row():
        # First column: every input control plus the submit button.
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio", type="filepath")
            model_selector = gr.Dropdown(
                label="Select Model",
                value="Voxtral Mini (3B)",
                choices=["Voxtral Mini (3B)", "Voxtral Small (24B)"],
            )
            language = gr.Dropdown(
                label="Language",
                value="en",
                choices=["en", "fr", "de", "es", "it", "pt", "nl", "ru", "zh", "ja", "ar"],
            )
            max_tokens = gr.Slider(
                label="Max Output Tokens",
                minimum=50,
                maximum=1000,
                step=50,
                value=500,
            )
            submit_btn = gr.Button("Process Audio")

        # Second column: the text produced by the model.
        with gr.Column():
            output_text = gr.Textbox(lines=10, label="Generated Response")

    # Clicking the button sends the four input values, in this order, to
    # process_audio and shows its return value in the output textbox.
    submit_btn.click(
        fn=process_audio,
        outputs=output_text,
        inputs=[audio_input, model_selector, language, max_tokens],
    )
# Launch the app
if __name__ == "__main__":
    # Turn on request queuing first, then start the server with a share link
    # requested.
    queued_app = demo.queue()
    queued_app.launch(share=True)