| |
| """ |
| Test VibeVoice vLLM API with Streaming and Optional Hotwords Support. |
| |
| This script tests ASR transcription via the vLLM OpenAI-compatible API. |
| By default, it runs standard transcription without hotwords. |
| |
| Optionally, you can provide hotwords (context_info) to improve recognition |
| of domain-specific content like proper nouns, technical terms, and speaker names. |
| Hotwords are embedded in the prompt as "with extra info: {hotwords}". |
| |
| Usage: |
| python test_api_with_hotwords.py [audio_path] [--url URL] [--hotwords "word1,word2"] |
| |
| Examples: |
| # Standard transcription (no hotwords) |
| python3 test_api.py audio.wav |
| |
| # With hotwords for better recognition of specific terms |
| python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice" |
| """ |
| import requests |
| import json |
| import base64 |
| import time |
| import sys |
| import os |
| import subprocess |
| import argparse |
|
|
|
|
| def _guess_mime_type(path: str) -> str: |
| """Guess MIME type from file extension.""" |
| ext = os.path.splitext(path)[1].lower() |
| mime_map = { |
| ".wav": "audio/wav", |
| ".mp3": "audio/mpeg", |
| ".m4a": "audio/mp4", |
| ".mp4": "video/mp4", |
| ".flac": "audio/flac", |
| ".ogg": "audio/ogg", |
| ".opus": "audio/ogg", |
| } |
| return mime_map.get(ext, "application/octet-stream") |
|
|
|
|
| def _get_duration_seconds_ffprobe(path: str) -> float: |
| """Get audio duration using ffprobe.""" |
| cmd = [ |
| "ffprobe", "-v", "error", |
| "-show_entries", "format=duration", |
| "-of", "default=noprint_wrappers=1:nokey=1", |
| path, |
| ] |
| out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip() |
| return float(out) |
|
|
|
|
| def _is_video_file(path: str) -> bool: |
| """Check if the file is a video file that needs audio extraction.""" |
| ext = os.path.splitext(path)[1].lower() |
| return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv") |
|
|
|
|
| def _extract_audio_from_video(video_path: str) -> str: |
| """ |
| Extract audio from video file (mp4/mov/webm) to a temporary mp3 file. |
| Returns the path to the extracted audio file. |
| """ |
| import tempfile |
| |
| fd, audio_path = tempfile.mkstemp(suffix=".mp3") |
| os.close(fd) |
| |
| cmd = [ |
| "ffmpeg", "-y", "-i", video_path, |
| "-vn", |
| "-acodec", "libmp3lame", |
| "-q:a", "2", |
| audio_path |
| ] |
| subprocess.run(cmd, check=True, capture_output=True) |
| return audio_path |
|
|
|
|
| def test_transcription_with_hotwords( |
| audio_path: str, |
| context_info: str = None, |
| base_url: str = "http://localhost:8000", |
| ): |
| """ |
| Test ASR transcription with customized hotwords. |
| |
| Hotwords are embedded in the prompt text as "with extra info: {hotwords}". |
| This helps the model recognize domain-specific terms more accurately. |
| |
| Args: |
| audio_path: Path to the audio file |
| context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice") |
| base_url: vLLM server URL |
| """ |
| |
| print(f"=" * 70) |
| print(f"Testing Customized Hotwords Support") |
| print(f"=" * 70) |
| print(f"Input file: {audio_path}") |
| print(f"Hotwords: {context_info or '(none)'}") |
| print() |
| |
| |
| temp_audio_path = None |
| actual_audio_path = audio_path |
| if _is_video_file(audio_path): |
| print(f"π¬ Detected video file, extracting audio...") |
| temp_audio_path = _extract_audio_from_video(audio_path) |
| actual_audio_path = temp_audio_path |
| print(f"β
Audio extracted to: {temp_audio_path}") |
| |
| |
| try: |
| duration = _get_duration_seconds_ffprobe(actual_audio_path) |
| print(f"Audio duration: {duration:.2f} seconds") |
| |
| with open(actual_audio_path, "rb") as f: |
| audio_bytes = f.read() |
| |
| audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") |
| print(f"Audio size: {len(audio_bytes)} bytes") |
| |
| except Exception as e: |
| print(f"Error preparing audio: {e}") |
| |
| if temp_audio_path and os.path.exists(temp_audio_path): |
| os.remove(temp_audio_path) |
| return |
|
|
| |
| url = f"{base_url}/v1/chat/completions" |
| |
| show_keys = ["Start time", "End time", "Speaker ID", "Content"] |
| |
| |
| |
| if context_info and context_info.strip(): |
| prompt_text = ( |
| f"This is a {duration:.2f} seconds audio, with extra info: {context_info.strip()}\n\n" |
| f"Please transcribe it with these keys: " + ", ".join(show_keys) |
| ) |
| print(f"\nπ Hotwords embedded in prompt: '{context_info}'") |
| else: |
| prompt_text = ( |
| f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: " |
| + ", ".join(show_keys) |
| ) |
| print(f"\nπ No hotwords provided") |
|
|
| mime = _guess_mime_type(actual_audio_path) |
| data_url = f"data:{mime};base64,{audio_b64}" |
|
|
| payload = { |
| "model": "vibevoice", |
| "messages": [ |
| { |
| "role": "system", |
| "content": "You are a helpful assistant that transcribes audio input into text output in JSON format." |
| }, |
| { |
| "role": "user", |
| "content": [ |
| {"type": "audio_url", "audio_url": {"url": data_url}}, |
| {"type": "text", "text": prompt_text} |
| ] |
| } |
| ], |
| "max_tokens": 32768, |
| "temperature": 0.0, |
| "stream": True, |
| "top_p": 1.0, |
| } |
| |
| print(f"\n{'=' * 70}") |
| print(f"Sending request to {url}") |
| print(f"{'=' * 70}") |
| |
| t0 = time.time() |
| try: |
| response = requests.post(url, json=payload, stream=True, timeout=12000) |
| |
| if response.status_code == 200: |
| print("\nβ
Response received. Streaming content:\n") |
| print("-" * 50) |
|
|
| printed = "" |
| for line in response.iter_lines(): |
| if line: |
| decoded_line = line.decode('utf-8') |
| |
| if decoded_line.startswith("data: "): |
| json_str = decoded_line[6:] |
| if json_str.strip() == "[DONE]": |
| print("\n" + "-" * 50) |
| print("β
[Finished]") |
| break |
| try: |
| data = json.loads(json_str) |
| delta = data['choices'][0]['delta'] |
| content = delta.get('content', '') |
| if content: |
| if content.startswith(printed): |
| to_print = content[len(printed):] |
| else: |
| to_print = content |
| if to_print: |
| print(to_print, end='', flush=True) |
| printed += to_print |
| except json.JSONDecodeError: |
| pass |
| else: |
| print(f"β Error: {response.status_code}") |
| print(response.text) |
| |
| except requests.exceptions.Timeout: |
| print("β Request timed out!") |
| except Exception as e: |
| print(f"β Error: {e}") |
| |
| elapsed = time.time() - t0 |
| print(f"\n{'=' * 70}") |
| print(f"β±οΈ Total time elapsed: {elapsed:.2f}s") |
| print(f"π RTF (Real-Time Factor): {elapsed / duration:.2f}x") |
| print(f"{'=' * 70}") |
| |
| |
| if temp_audio_path and os.path.exists(temp_audio_path): |
| os.remove(temp_audio_path) |
| print(f"ποΈ Cleaned up temp file: {temp_audio_path}") |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser( |
| description="Test VibeVoice vLLM API with Customized Hotwords" |
| ) |
| parser.add_argument( |
| "audio_path", |
| help="Path to audio file (wav, mp3, flac, etc.) or video file" |
| ) |
| parser.add_argument( |
| "--url", |
| default="http://localhost:8000", |
| help="vLLM server URL (default: http://localhost:8000)" |
| ) |
| parser.add_argument( |
| "--hotwords", |
| type=str, |
| default=None, |
| help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')" |
| ) |
| |
| args = parser.parse_args() |
| |
| |
| test_transcription_with_hotwords( |
| audio_path=args.audio_path, |
| context_info=args.hotwords, |
| base_url=args.url, |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|