| from __future__ import annotations |
|
|
| import os |
| import subprocess |
| import tempfile |
| import time |
| from functools import lru_cache |
| from pathlib import Path |
|
|
| import gradio as gr |
|
|
try:
    import spaces
except ImportError:

    class _SpacesFallback:
        """No-op stand-in used when the optional `spaces` package is not installed."""

        @staticmethod
        def GPU(func=None, **_options):
            """Return the decorated function unchanged.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            decorator factory called with keyword options
            (``@spaces.GPU(duration=...)``); any options are ignored.
            """
            if callable(func):
                # Bare form: the function itself was passed in.
                return func

            def passthrough(wrapped):
                # Factory form: hand back an identity decorator.
                return wrapped

            return passthrough

    spaces = _SpacesFallback()
|
|
| from src.hf_inference import MossAudioHFInference, read_env_model_id, resolve_device |
|
|
# Passed to gr.Blocks as the page title and rendered as the top heading.
TITLE = "MOSS-Audio-8B-Thinking Demo"

# File suffixes (lower-case, with leading dot) whose audio track must be
# extracted with ffmpeg before inference.
VIDEO_EXTENSIONS = {".mp4"}

# Prompt used as the textbox's initial value and as the fallback when the
# submitted question is empty.
DEFAULT_QUESTION = "Describe this audio."

# Initial values of the generation sliders in the "Advanced Settings" panel.
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 1.0
DEFAULT_TEMPERATURE = 1.0
DEFAULT_MAX_NEW_TOKENS = 1024
|
|
|
|
@lru_cache(maxsize=2)
def get_inference(model_name_or_path: str, device: str) -> MossAudioHFInference:
    """Return the inference wrapper for a model/device pair.

    Results are memoized (up to two distinct argument pairs), so repeated
    calls with the same arguments reuse the already-constructed wrapper.
    A call that raises is not cached and will be retried on the next call.
    """
    settings = {
        "model_name_or_path": model_name_or_path,
        "device": device,
        "torch_dtype": "auto",
        "enable_time_marker": True,
    }
    return MossAudioHFInference(**settings)
|
|
|
|
def format_status(model_name_or_path: str, device: str, elapsed_seconds: float) -> str:
    """Build the Markdown status summary for a finished run.

    Args:
        model_name_or_path: Model identifier shown to the user.
        device: Device name shown to the user.
        elapsed_seconds: Wall-clock duration; rendered with two decimals.

    Returns:
        Three Markdown lines (model, device, elapsed time), each value
        wrapped in inline-code backticks.
    """
    # Markdown only starts a new line on a "hard break": two trailing spaces
    # before the newline. A single space is a soft break and would merge the
    # three fields into one rendered line. Keeping the marker in one place
    # stops it from being lost to trailing-whitespace cleanup.
    hard_break = "  \n"
    fields = (
        f"Model: `{model_name_or_path}`",
        f"Device: `{device}`",
        f"Elapsed: `{elapsed_seconds:.2f}s`",
    )
    return hard_break.join(fields)
|
|
|
|
def convert_media_to_mp3(media_path: str, output_path: str) -> None:
    """Extract the audio track of ``media_path`` into an MP3 at ``output_path``.

    Runs the ``ffmpeg`` executable found on ``PATH``, dropping any video
    stream (``-vn``) and encoding the audio with ``libmp3lame``. An existing
    output file is overwritten (``-y``).

    Args:
        media_path: Path of the input media file.
        output_path: Path of the MP3 file to create.

    Raises:
        gr.Error: If ``ffmpeg`` is not installed, or if it exits with a
            non-zero status (its stderr output is appended to the message).
    """
    command = [
        "ffmpeg",
        # Never let ffmpeg read the server process's stdin.
        "-nostdin",
        # Keep stderr limited to real errors, since it is shown to the user.
        "-hide_banner",
        "-loglevel",
        "error",
        "-y",
        "-i",
        media_path,
        "-vn",
        "-acodec",
        "libmp3lame",
        output_path,
    ]
    try:
        subprocess.run(
            command,
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            text=True,
            # ffmpeg may echo non-UTF-8 metadata; do not crash while decoding it.
            errors="replace",
        )
    except FileNotFoundError as exc:
        # Raised by subprocess when the executable itself cannot be found.
        raise gr.Error(
            "Failed to extract audio from the uploaded media: the ffmpeg executable was not found on this server."
        ) from exc
    except subprocess.CalledProcessError as exc:
        raise gr.Error(
            f"Failed to extract audio from the uploaded media. Please make sure the mp4 file is valid and decodable.\n{exc.stderr}"
        ) from exc
|
|
|
|
def resolve_media_path(audio_path: str | None, video_path: str | None) -> str | None:
    """Pick the media file to process, preferring the video upload.

    Returns ``video_path`` when it is non-empty, otherwise ``audio_path``
    (which may itself be ``None`` or empty).
    """
    return video_path or audio_path
|
|
|
|
@spaces.GPU
def run_inference(
    audio_path: str | None,
    video_path: str | None,
    question: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
):
    """Answer a question about the uploaded audio (or a video's audio track).

    Args:
        audio_path: Path of the uploaded/recorded audio file, if any.
        video_path: Path of the uploaded video file, if any; takes precedence
            over ``audio_path``.
        question: User prompt; blank input falls back to ``DEFAULT_QUESTION``.
        max_new_tokens: Forwarded to ``inference.generate``.
        temperature: Forwarded to ``inference.generate``; sampling is enabled
            only when this is greater than zero.
        top_p: Forwarded to ``inference.generate``.
        top_k: Forwarded to ``inference.generate``.

    Returns:
        A ``(answer, status_markdown)`` pair for the two output components.

    Raises:
        gr.Error: If the model cannot be loaded, the audio cannot be
            extracted from a video, or generation fails.
    """
    prompt = (question or "").strip() or DEFAULT_QUESTION
    model_name_or_path = read_env_model_id()
    device = resolve_device()

    try:
        inference = get_inference(model_name_or_path, device)
    except Exception as exc:
        raise gr.Error(
            f"Failed to load the model. Please check the weights path or Hugging Face download status.\n{exc}"
        ) from exc

    media_path = resolve_media_path(audio_path, video_path)

    try:
        started_at = time.perf_counter()
        # The temporary directory only needs to outlive generate(); it is
        # removed as soon as the with-block exits.
        with tempfile.TemporaryDirectory(prefix="moss-audio-") as temp_dir:
            prepared_audio_path = media_path
            if media_path and Path(media_path).suffix.lower() in VIDEO_EXTENSIONS:
                # Video uploads are reduced to their audio track first.
                prepared_audio_path = os.path.join(temp_dir, "input.mp3")
                convert_media_to_mp3(media_path, prepared_audio_path)

            answer = inference.generate(
                question=prompt,
                audio_path=prepared_audio_path,
                max_new_tokens=max_new_tokens,
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        elapsed_seconds = time.perf_counter() - started_at
    except gr.Error:
        # Already a user-facing error (e.g. from convert_media_to_mp3);
        # re-raise untouched so its specific message is not buried inside
        # the generic "Inference failed" wrapper below.
        raise
    except Exception as exc:
        raise gr.Error(
            f"Inference failed. Please make sure the uploaded file is readable and the format is supported.\n{exc}"
        ) from exc

    return answer, format_status(model_name_or_path, device, elapsed_seconds)
|
|
|
|
def _default_sampling_values() -> tuple[int, float, float, int]:
    """Return the initial slider values: max new tokens, temperature, top-p, top-k."""
    return (
        DEFAULT_MAX_NEW_TOKENS,
        DEFAULT_TEMPERATURE,
        DEFAULT_TOP_P,
        DEFAULT_TOP_K,
    )


# Two-column page: inputs and generation settings on the left, model output
# and run status on the right, prompt examples underneath.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")

    with gr.Row():
        with gr.Column(scale=5):
            audio_input = gr.Audio(
                label="Audio",
                sources=["upload", "microphone"],
                type="filepath",
            )
            with gr.Accordion("Optional Video Input (.mp4)", open=False):
                gr.Markdown(
                    "Upload an mp4 only when needed. If a video is provided, its audio track will be extracted and used for inference."
                )
                video_input = gr.File(
                    label="Video File",
                    file_types=[".mp4"],
                    type="filepath",
                )
            question_input = gr.Textbox(
                label="Prompt",
                lines=4,
                value=DEFAULT_QUESTION,
                placeholder="For example: Please transcribe this audio. Describe the sounds in this clip. What emotion does the speaker convey?",
            )

            with gr.Accordion("Advanced Settings", open=False):
                max_new_tokens_input = gr.Slider(
                    minimum=64,
                    maximum=2048,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    step=32,
                    label="Max New Tokens",
                )
                temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=DEFAULT_TEMPERATURE,
                    step=0.1,
                    label="Temperature",
                )
                top_p_input = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=DEFAULT_TOP_P,
                    step=0.05,
                    label="Top-p",
                )
                top_k_input = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=DEFAULT_TOP_K,
                    step=1,
                    label="Top-k",
                )

            with gr.Row():
                submit_btn = gr.Button("Generate", variant="primary")
                # Only the media and prompt inputs are emptied by the built-in
                # clear behaviour. The sliders are handled separately below:
                # clearing a Slider sets it to its minimum, not to its
                # configured default.
                clear_btn = gr.ClearButton(
                    [
                        audio_input,
                        video_input,
                        question_input,
                    ],
                    value="Clear",
                )

        with gr.Column(scale=5):
            output_text = gr.Textbox(label="Output", lines=16)
            status_text = gr.Markdown("Waiting for input.")

    # Clicking an example only fills the prompt textbox.
    gr.Examples(
        examples=[
            ["Describe this audio."],
            ["Please transcribe this audio."],
            ["What is happening in this audio clip?"],
            ["Describe the speaker's voice characteristics in detail."],
            ["What emotion does the speaker convey?"],
        ],
        inputs=[question_input],
        label="Prompt Examples",
    )

    # "Clear" puts the generation sliders back to their defaults. It skips the
    # queue so it still responds while inference jobs are waiting.
    clear_btn.click(
        fn=_default_sampling_values,
        outputs=[
            max_new_tokens_input,
            temperature_input,
            top_p_input,
            top_k_input,
        ],
        queue=False,
    )

    # Input order must match the positional parameters of run_inference.
    submit_btn.click(
        fn=run_inference,
        inputs=[
            audio_input,
            video_input,
            question_input,
            max_new_tokens_input,
            temperature_input,
            top_p_input,
            top_k_input,
        ],
        outputs=[output_text, status_text],
    )
|
|
|
|
if __name__ == "__main__":
    # Bind address and port can be overridden through the environment; the
    # defaults listen on all interfaces at port 7860.
    launch_options = {
        "server_name": os.getenv("MOSS_AUDIO_SERVER_NAME", "0.0.0.0"),
        "server_port": int(os.getenv("MOSS_AUDIO_SERVER_PORT", "7860")),
        "ssr_mode": False,
    }
    # At most eight requests may wait in the queue at a time.
    queued_demo = demo.queue(max_size=8)
    queued_demo.launch(**launch_options)
|
|