File size: 8,846 Bytes
0a81958
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
#!/usr/bin/env python3
"""
Test VibeVoice vLLM API with Streaming and Optional Hotwords Support.

This script tests ASR transcription via the vLLM OpenAI-compatible API.
By default, it runs standard transcription without hotwords.

Optionally, you can provide hotwords (context_info) to improve recognition
of domain-specific content like proper nouns, technical terms, and speaker names.
Hotwords are embedded in the prompt as "with extra info: {hotwords}".

Usage:
    python3 test_api.py audio_path [--url URL] [--hotwords "word1,word2"]
    
Examples:
    # Standard transcription (no hotwords)
    python3 test_api.py audio.wav
    
    # With hotwords for better recognition of specific terms
    python3 test_api.py audio.wav --hotwords "Microsoft,Azure,VibeVoice"
"""
import requests
import json
import base64
import time
import sys
import os
import subprocess
import argparse


def _guess_mime_type(path: str) -> str:
    """Guess MIME type from file extension."""
    ext = os.path.splitext(path)[1].lower()
    mime_map = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".m4a": "audio/mp4",
        ".mp4": "video/mp4",
        ".flac": "audio/flac",
        ".ogg": "audio/ogg",
        ".opus": "audio/ogg",
    }
    return mime_map.get(ext, "application/octet-stream")


def _get_duration_seconds_ffprobe(path: str) -> float:
    """Get audio duration using ffprobe."""
    cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        path,
    ]
    out = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode("utf-8").strip()
    return float(out)


def _is_video_file(path: str) -> bool:
    """Check if the file is a video file that needs audio extraction."""
    ext = os.path.splitext(path)[1].lower()
    return ext in (".mp4", ".m4v", ".mov", ".webm", ".avi", ".mkv")


def _extract_audio_from_video(video_path: str) -> str:
    """
    Extract audio from video file (mp4/mov/webm) to a temporary mp3 file.
    Returns the path to the extracted audio file.
    """
    import tempfile
    # Create temp file with .mp3 extension
    fd, audio_path = tempfile.mkstemp(suffix=".mp3")
    os.close(fd)
    
    cmd = [
        "ffmpeg", "-y", "-i", video_path,
        "-vn",  # No video
        "-acodec", "libmp3lame",
        "-q:a", "2",  # High quality
        audio_path
    ]
    subprocess.run(cmd, check=True, capture_output=True)
    return audio_path


def _build_transcription_prompt(duration: float, context_info: str = None) -> str:
    """Build the user prompt, embedding hotwords when any were supplied.

    Args:
        duration: Audio length in seconds (rendered with two decimals).
        context_info: Optional hotwords string; ``None`` or blank means none.

    Returns:
        The prompt text sent alongside the audio.
    """
    keys = ", ".join(["Start time", "End time", "Speaker ID", "Content"])
    hotwords = (context_info or "").strip()
    if hotwords:
        return (
            f"This is a {duration:.2f} seconds audio, with extra info: {hotwords}\n\n"
            f"Please transcribe it with these keys: {keys}"
        )
    return f"This is a {duration:.2f} seconds audio, please transcribe it with these keys: {keys}"


def _print_streamed_content(response) -> None:
    """Print the text deltas of a streamed chat-completion response.

    Args:
        response: A streaming ``requests.Response`` whose body consists of
            ``data: <json>`` lines and is terminated by ``data: [DONE]``.
    """
    printed = ""
    for raw_line in response.iter_lines():
        if not raw_line:
            continue  # blank separator line between events
        line = raw_line.decode("utf-8")
        if not line.startswith("data:"):
            continue
        json_str = line[len("data:"):].strip()
        if json_str == "[DONE]":
            print("\n" + "-" * 50)
            print("✅ [Finished]")
            break
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError:
            continue  # best-effort: skip a malformed chunk, keep streaming
        if not isinstance(data, dict):
            continue
        if "error" in data:
            # Surface an in-stream error payload instead of failing on the
            # missing "choices" key.
            print(f"\n❌ Error: {data['error']}")
            continue
        choices = data.get("choices") or []
        if not choices:
            continue  # chunk carries no delta (nothing to print)
        delta = choices[0].get("delta") or {}
        content = delta.get("content") or ""
        if not content:
            continue
        # A chunk that repeats everything printed so far is treated as
        # cumulative and only its new tail is printed; anything else is
        # treated as an incremental delta.
        if content.startswith(printed):
            to_print = content[len(printed):]
        else:
            to_print = content
        if to_print:
            print(to_print, end="", flush=True)
            printed += to_print


def test_transcription_with_hotwords(
    audio_path: str,
    context_info: str = None,
    base_url: str = "http://localhost:8000",
):
    """
    Test ASR transcription with customized hotwords.

    Hotwords are embedded in the prompt text as "with extra info: {hotwords}".
    The audio is sent base64-encoded as a data URL to
    ``{base_url}/v1/chat/completions`` and the streamed answer is printed as
    it arrives, followed by the elapsed time and real-time factor.

    Video inputs are first converted to a temporary mp3, which is always
    removed before returning, including on errors and interruption.

    Args:
        audio_path: Path to the audio (or video) file
        context_info: Hotwords string (e.g., "Microsoft,Azure,VibeVoice")
        base_url: vLLM server URL (a trailing slash is tolerated)

    Returns:
        None. All results and errors are reported on standard output.
    """
    print("=" * 70)
    print("Testing Customized Hotwords Support")
    print("=" * 70)
    print(f"Input file: {audio_path}")
    print(f"Hotwords: {context_info or '(none)'}")
    print()

    temp_audio_path = None
    try:
        # Prepare the audio: extract from video if needed, then load it.
        try:
            actual_audio_path = audio_path
            if _is_video_file(audio_path):
                print("🎬 Detected video file, extracting audio...")
                temp_audio_path = _extract_audio_from_video(audio_path)
                actual_audio_path = temp_audio_path
                print(f"✅ Audio extracted to: {temp_audio_path}")

            duration = _get_duration_seconds_ffprobe(actual_audio_path)
            print(f"Audio duration: {duration:.2f} seconds")

            with open(actual_audio_path, "rb") as f:
                audio_bytes = f.read()

            audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
            print(f"Audio size: {len(audio_bytes)} bytes")
        except Exception as e:
            print(f"Error preparing audio: {e}")
            return

        # Build the request
        url = f"{base_url.rstrip('/')}/v1/chat/completions"

        prompt_text = _build_transcription_prompt(duration, context_info)
        if context_info and context_info.strip():
            print(f"\n📝 Hotwords embedded in prompt: '{context_info}'")
        else:
            print("\n📝 No hotwords provided")

        mime = _guess_mime_type(actual_audio_path)
        data_url = f"data:{mime};base64,{audio_b64}"

        payload = {
            "model": "vibevoice",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant that transcribes audio input into text output in JSON format."
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "audio_url", "audio_url": {"url": data_url}},
                        {"type": "text", "text": prompt_text},
                    ],
                },
            ],
            "max_tokens": 32768,
            "temperature": 0.0,
            "stream": True,
            "top_p": 1.0,
        }

        print(f"\n{'=' * 70}")
        print(f"Sending request to {url}")
        print(f"{'=' * 70}")

        t0 = time.perf_counter()
        try:
            # The context manager closes the streamed connection even when
            # the stream is abandoned early or an exception is raised.
            with requests.post(url, json=payload, stream=True, timeout=12000) as response:
                if response.status_code == 200:
                    print("\n✅ Response received. Streaming content:\n")
                    print("-" * 50)
                    _print_streamed_content(response)
                else:
                    print(f"❌ Error: {response.status_code}")
                    print(response.text)
        except requests.exceptions.Timeout:
            print("❌ Request timed out!")
        except Exception as e:
            print(f"❌ Error: {e}")

        elapsed = time.perf_counter() - t0
        print(f"\n{'=' * 70}")
        print(f"⏱️  Total time elapsed: {elapsed:.2f}s")
        if duration > 0:
            print(f"📊 RTF (Real-Time Factor): {elapsed / duration:.2f}x")
        else:
            # Avoid dividing by zero for an empty / zero-length input.
            print("📊 RTF (Real-Time Factor): n/a (audio duration is 0)")
        print(f"{'=' * 70}")
    finally:
        # Cleanup temp audio file if created
        if temp_audio_path and os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
            print(f"🗑️  Cleaned up temp file: {temp_audio_path}")


def main():
    """Parse the command line and run one transcription test.

    Positional ``audio_path`` is required; ``--url`` selects the server and
    ``--hotwords`` optionally supplies the hotwords string.
    """
    cli = argparse.ArgumentParser(description="Test VibeVoice vLLM API with Customized Hotwords")
    cli.add_argument("audio_path", help="Path to audio file (wav, mp3, flac, etc.) or video file")
    cli.add_argument(
        "--url",
        help="vLLM server URL (default: http://localhost:8000)",
        default="http://localhost:8000",
    )
    cli.add_argument(
        "--hotwords",
        help="Hotwords to improve recognition (e.g., 'Microsoft,Azure,VibeVoice')",
        default=None,
        type=str,
    )
    options = cli.parse_args()

    # Hand the parsed options to the actual test routine.
    test_transcription_with_hotwords(
        options.audio_path,
        base_url=options.url,
        context_info=options.hotwords,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()