andrevp commited on
Commit
35ab828
·
verified ·
1 Parent(s): 505acff

Add full duplex streaming mode (streaming.py)

Browse files
Files changed (1) hide show
  1. streaming.py +590 -0
streaming.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Full duplex streaming mode for MiniCPM-o 4.5 MLX.
2
+
3
+ Captures screen video + system audio, processes through the model in real-time,
4
+ and outputs text analysis with optional TTS playback.
5
+
6
+ Architecture:
7
+ [Screen 1fps] + [Audio 16kHz] -> ChunkSynchronizer -> DuplexGenerator -> TTSPlayback
8
+ """
9
+
10
+ import queue
11
+ import threading
12
+ import time
13
+ from typing import Optional
14
+
15
+ import mlx.core as mx
16
+ import numpy as np
17
+
18
+
19
+ class ScreenCapture:
20
+ """Capture screen region at 1fps using mss.
21
+
22
+ Produces (H, W, C) float32 frames resized to 448x448.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ out_queue: queue.Queue,
28
+ region: Optional[tuple] = None,
29
+ fps: float = 1.0,
30
+ target_size: int = 448,
31
+ ):
32
+ self.out_queue = out_queue
33
+ self.region = region # (x, y, w, h) or None for primary monitor
34
+ self.fps = fps
35
+ self.target_size = target_size
36
+ self._stop = threading.Event()
37
+ self._thread: Optional[threading.Thread] = None
38
+
39
+ def start(self):
40
+ self._stop.clear()
41
+ self._thread = threading.Thread(target=self._run, daemon=True)
42
+ self._thread.start()
43
+
44
+ def stop(self):
45
+ self._stop.set()
46
+ if self._thread:
47
+ self._thread.join(timeout=2)
48
+
49
+ def _run(self):
50
+ import mss
51
+ from PIL import Image
52
+
53
+ with mss.mss() as sct:
54
+ if self.region:
55
+ x, y, w, h = self.region
56
+ monitor = {"left": x, "top": y, "width": w, "height": h}
57
+ else:
58
+ monitor = sct.monitors[1] # Primary monitor
59
+
60
+ while not self._stop.is_set():
61
+ t0 = time.time()
62
+ screenshot = sct.grab(monitor)
63
+ # Convert to PIL Image, resize, convert to float32
64
+ img = Image.frombytes("RGB", screenshot.size, screenshot.rgb)
65
+ img = img.resize(
66
+ (self.target_size, self.target_size), Image.BILINEAR
67
+ )
68
+ frame = np.array(img, dtype=np.float32) / 255.0 # (H, W, 3)
69
+
70
+ try:
71
+ self.out_queue.put_nowait(
72
+ {"type": "video", "frame": frame, "time": time.time()}
73
+ )
74
+ except queue.Full:
75
+ pass # Drop frame if queue full
76
+
77
+ elapsed = time.time() - t0
78
+ sleep_time = max(0, (1.0 / self.fps) - elapsed)
79
+ if sleep_time > 0:
80
+ self._stop.wait(sleep_time)
81
+
82
+
83
+ class AudioCapture:
84
+ """Capture system audio at 16kHz using sounddevice.
85
+
86
+ Uses BlackHole virtual audio device for system audio loopback on macOS.
87
+ Produces 1-second mono float32 audio chunks.
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ out_queue: queue.Queue,
93
+ device: Optional[str] = None,
94
+ sample_rate: int = 16000,
95
+ chunk_seconds: float = 1.0,
96
+ ):
97
+ self.out_queue = out_queue
98
+ self.device = device # Device name or index
99
+ self.sample_rate = sample_rate
100
+ self.chunk_seconds = chunk_seconds
101
+ self.chunk_samples = int(sample_rate * chunk_seconds)
102
+ self._stop = threading.Event()
103
+ self._thread: Optional[threading.Thread] = None
104
+
105
+ def start(self):
106
+ self._stop.clear()
107
+ self._thread = threading.Thread(target=self._run, daemon=True)
108
+ self._thread.start()
109
+
110
+ def stop(self):
111
+ self._stop.set()
112
+ if self._thread:
113
+ self._thread.join(timeout=2)
114
+
115
+ def _find_device(self):
116
+ """Find audio device by name."""
117
+ import sounddevice as sd
118
+
119
+ if self.device is None:
120
+ return None # Use default
121
+
122
+ if isinstance(self.device, int):
123
+ return self.device
124
+
125
+ devices = sd.query_devices()
126
+ for i, d in enumerate(devices):
127
+ if self.device.lower() in d["name"].lower() and d["max_input_channels"] > 0:
128
+ return i
129
+
130
+ print(f"Warning: Audio device '{self.device}' not found, using default.")
131
+ return None
132
+
133
+ def _run(self):
134
+ import sounddevice as sd
135
+
136
+ device_id = self._find_device()
137
+ buffer = np.array([], dtype=np.float32)
138
+
139
+ def callback(indata, frames, time_info, status):
140
+ nonlocal buffer
141
+ if status:
142
+ pass # Ignore overflow/underflow
143
+ mono = indata.mean(axis=1) if indata.ndim > 1 else indata.flatten()
144
+ buffer = np.concatenate([buffer, mono])
145
+
146
+ try:
147
+ with sd.InputStream(
148
+ device=device_id,
149
+ channels=1,
150
+ samplerate=self.sample_rate,
151
+ blocksize=1024,
152
+ callback=callback,
153
+ ):
154
+ while not self._stop.is_set():
155
+ if len(buffer) >= self.chunk_samples:
156
+ chunk = buffer[: self.chunk_samples].copy()
157
+ buffer = buffer[self.chunk_samples :]
158
+ try:
159
+ self.out_queue.put_nowait(
160
+ {
161
+ "type": "audio",
162
+ "data": chunk,
163
+ "time": time.time(),
164
+ }
165
+ )
166
+ except queue.Full:
167
+ pass
168
+ else:
169
+ self._stop.wait(0.05)
170
+ except Exception as e:
171
+ print(f"Audio capture error: {e}")
172
+
173
+
174
+ class ChunkSynchronizer:
175
+ """Synchronize video frames and audio into 1-second chunks.
176
+
177
+ Pairs the latest video frame with each 1-second audio chunk.
178
+ Runs mel processing on the audio.
179
+ """
180
+
181
+ def __init__(
182
+ self,
183
+ raw_queue: queue.Queue,
184
+ sync_queue: queue.Queue,
185
+ mel_processor,
186
+ ):
187
+ self.raw_queue = raw_queue
188
+ self.sync_queue = sync_queue
189
+ self.mel_processor = mel_processor
190
+ self._stop = threading.Event()
191
+ self._thread: Optional[threading.Thread] = None
192
+ self._latest_frame: Optional[np.ndarray] = None
193
+
194
+ def start(self):
195
+ self._stop.clear()
196
+ self._thread = threading.Thread(target=self._run, daemon=True)
197
+ self._thread.start()
198
+
199
+ def stop(self):
200
+ self._stop.set()
201
+ if self._thread:
202
+ self._thread.join(timeout=2)
203
+
204
+ def _run(self):
205
+ while not self._stop.is_set():
206
+ try:
207
+ item = self.raw_queue.get(timeout=0.1)
208
+ except queue.Empty:
209
+ continue
210
+
211
+ if item["type"] == "video":
212
+ self._latest_frame = item["frame"]
213
+ elif item["type"] == "audio":
214
+ self.mel_processor.add_audio(item["data"])
215
+ mel_chunk = self.mel_processor.get_mel_chunk()
216
+ if mel_chunk is not None:
217
+ try:
218
+ self.sync_queue.put_nowait(
219
+ {
220
+ "video_frame": self._latest_frame,
221
+ "mel_chunk": mel_chunk,
222
+ "time": item["time"],
223
+ }
224
+ )
225
+ except queue.Full:
226
+ pass # Drop if consumer is slow
227
+
228
+
229
+ class DuplexGenerator:
230
+ """Main processing loop for full duplex streaming.
231
+
232
+ Dequeues synchronized chunks, runs model inference, generates text responses,
233
+ and optionally queues TTS audio for playback.
234
+ """
235
+
236
+ def __init__(
237
+ self,
238
+ model,
239
+ processor,
240
+ sync_queue: queue.Queue,
241
+ tts_queue: Optional[queue.Queue] = None,
242
+ temperature: float = 0.0,
243
+ max_tokens_per_chunk: int = 50,
244
+ enable_tts: bool = False,
245
+ ):
246
+ self.model = model
247
+ self.processor = processor
248
+ self.sync_queue = sync_queue
249
+ self.tts_queue = tts_queue
250
+ self.temperature = temperature
251
+ self.max_tokens = max_tokens_per_chunk
252
+ self.enable_tts = enable_tts
253
+ self._stop = threading.Event()
254
+ self._thread: Optional[threading.Thread] = None
255
+ self.ctx = None
256
+ self.chunk_count = 0
257
+ self.on_text = None # callback(text: str)
258
+ self.on_status = None # callback(status: dict)
259
+
260
+ def start(self):
261
+ self._stop.clear()
262
+ self._thread = threading.Thread(target=self._run, daemon=True)
263
+ self._thread.start()
264
+
265
+ def stop(self):
266
+ self._stop.set()
267
+ if self._thread:
268
+ self._thread.join(timeout=5)
269
+
270
+ def _build_chunk_prompt(self, has_video: bool, has_audio: bool):
271
+ """Build prompt tokens for one streaming chunk.
272
+
273
+ Returns:
274
+ dict with input_ids, image_bound, audio_bound
275
+ """
276
+ tokenizer = self.processor.tokenizer
277
+
278
+ parts = []
279
+ parts.append("<|im_start|>user\n")
280
+
281
+ image_bound = []
282
+ audio_bound = []
283
+
284
+ # Video placeholder
285
+ if has_video:
286
+ # 64 query tokens for resampled image
287
+ n_img_tokens = self.model.config.query_num # 64
288
+ img_placeholder = "<image>" + "<unk>" * n_img_tokens + "</image>"
289
+ parts.append(img_placeholder)
290
+
291
+ # Audio placeholder
292
+ if has_audio:
293
+ # Approximate audio tokens: ~10 after pooling for 1 second
294
+ n_audio_tokens = 10
295
+ audio_placeholder = (
296
+ "<|audio_start|>" + "<unk>" * n_audio_tokens + "<|audio_end|>"
297
+ )
298
+ parts.append(audio_placeholder)
299
+
300
+ parts.append("\nDescribe what you see and hear.<|im_end|>\n")
301
+ parts.append("<|im_start|>assistant\n")
302
+
303
+ text = "".join(parts)
304
+ tokenized = tokenizer(text, return_tensors="np")
305
+ input_ids = mx.array(tokenized["input_ids"])
306
+
307
+ # Find image_bound and audio_bound positions
308
+ ids_list = tokenized["input_ids"][0].tolist()
309
+ unk_id = tokenizer.convert_tokens_to_ids("<unk>")
310
+
311
+ if has_video:
312
+ img_start_id = tokenizer.convert_tokens_to_ids("<image>")
313
+ img_end_id = tokenizer.convert_tokens_to_ids("</image>")
314
+ in_img = False
315
+ start_idx = None
316
+ for i, tok in enumerate(ids_list):
317
+ if tok == img_start_id:
318
+ in_img = True
319
+ start_idx = i + 1
320
+ elif tok == img_end_id and in_img:
321
+ image_bound.append((start_idx, i))
322
+ in_img = False
323
+
324
+ if has_audio:
325
+ audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_start|>")
326
+ audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_end|>")
327
+ in_audio = False
328
+ start_idx = None
329
+ for i, tok in enumerate(ids_list):
330
+ if tok == audio_start_id:
331
+ in_audio = True
332
+ start_idx = i + 1
333
+ elif tok == audio_end_id and in_audio:
334
+ audio_bound.append((start_idx, i))
335
+ in_audio = False
336
+
337
+ return {
338
+ "input_ids": input_ids,
339
+ "image_bound": image_bound if image_bound else None,
340
+ "audio_bound": audio_bound if audio_bound else None,
341
+ }
342
+
343
+ def _prepare_video_frame(self, frame: np.ndarray):
344
+ """Prepare a video frame for model input.
345
+
346
+ Args:
347
+ frame: (H, W, 3) float32 frame
348
+
349
+ Returns:
350
+ (pixel_values, tgt_sizes, patch_attention_mask)
351
+ """
352
+ # Frame is already (448, 448, 3) float32
353
+ # Add batch dimension: (1, H, W, 3)
354
+ pv = mx.array(frame[np.newaxis, ...])
355
+
356
+ # Compute patch sizes
357
+ h_patches = frame.shape[0] // 14 # 32
358
+ w_patches = frame.shape[1] // 14 # 32
359
+ tgt_sizes = mx.array([[h_patches, w_patches]], dtype=mx.int32)
360
+
361
+ total_patches = h_patches * w_patches
362
+ patch_attention_mask = mx.ones((1, total_patches), dtype=mx.bool_)
363
+
364
+ return pv, tgt_sizes, patch_attention_mask
365
+
366
+ def _run(self):
367
+ # Initialize streaming context
368
+ self.ctx = self.model.init_streaming()
369
+ self.chunk_count = 0
370
+
371
+ while not self._stop.is_set():
372
+ try:
373
+ chunk = self.sync_queue.get(timeout=0.5)
374
+ except queue.Empty:
375
+ continue
376
+
377
+ t0 = time.time()
378
+ self.chunk_count += 1
379
+
380
+ video_frame = chunk.get("video_frame")
381
+ mel_chunk = chunk.get("mel_chunk")
382
+
383
+ has_video = video_frame is not None
384
+ has_audio = mel_chunk is not None
385
+
386
+ if not has_video and not has_audio:
387
+ continue
388
+
389
+ # Build prompt for this chunk
390
+ prompt = self._build_chunk_prompt(has_video, has_audio)
391
+
392
+ # Prepare video
393
+ pixel_values = None
394
+ tgt_sizes = None
395
+ patch_attention_mask = None
396
+ if has_video:
397
+ pixel_values, tgt_sizes, patch_attention_mask = (
398
+ self._prepare_video_frame(video_frame)
399
+ )
400
+
401
+ # Process chunk through model
402
+ logits = self.model.process_streaming_chunk(
403
+ ctx=self.ctx,
404
+ video_frame=pixel_values,
405
+ audio_chunk=mel_chunk,
406
+ prompt_tokens=prompt["input_ids"],
407
+ image_bound=prompt["image_bound"],
408
+ audio_bound=prompt["audio_bound"],
409
+ tgt_sizes=tgt_sizes,
410
+ patch_attention_mask=patch_attention_mask,
411
+ )
412
+
413
+ # Generate text response
414
+ tokens = self.model.streaming_generate(
415
+ ctx=self.ctx,
416
+ logits=logits,
417
+ tokenizer=self.processor.tokenizer,
418
+ max_tokens=self.max_tokens,
419
+ temperature=self.temperature,
420
+ )
421
+
422
+ elapsed = time.time() - t0
423
+
424
+ if tokens:
425
+ text = self.processor.tokenizer.decode(
426
+ tokens, skip_special_tokens=True
427
+ )
428
+ if self.on_text and text.strip():
429
+ self.on_text(text.strip())
430
+
431
+ # TTS if enabled
432
+ if self.enable_tts and self.tts_queue and tokens:
433
+ self.tts_queue.put_nowait(
434
+ {"tokens": tokens, "text": text}
435
+ )
436
+
437
+ if self.on_status:
438
+ self.on_status(
439
+ {
440
+ "chunk": self.chunk_count,
441
+ "mode": self.ctx.mode,
442
+ "cache_tokens": self.ctx.total_tokens,
443
+ "latency_ms": int(elapsed * 1000),
444
+ "mem_gb": mx.get_peak_memory() / 1e9,
445
+ }
446
+ )
447
+
448
+
449
+ class TTSPlayback:
450
+ """Dequeue TTS tokens, convert to audio, and play back.
451
+
452
+ Uses Token2wav vocoder for audio synthesis and sounddevice for playback.
453
+ """
454
+
455
+ def __init__(self, tts_queue: queue.Queue, sample_rate: int = 24000):
456
+ self.tts_queue = tts_queue
457
+ self.sample_rate = sample_rate
458
+ self._stop = threading.Event()
459
+ self._thread: Optional[threading.Thread] = None
460
+ self._vocoder = None
461
+
462
+ def start(self):
463
+ self._stop.clear()
464
+ self._thread = threading.Thread(target=self._run, daemon=True)
465
+ self._thread.start()
466
+
467
+ def stop(self):
468
+ self._stop.set()
469
+ if self._thread:
470
+ self._thread.join(timeout=2)
471
+
472
+ def _run(self):
473
+ import sounddevice as sd
474
+
475
+ # Try loading vocoder
476
+ try:
477
+ from stepaudio2 import Token2wav
478
+ self._vocoder = Token2wav()
479
+ except ImportError:
480
+ print("TTSPlayback: Token2wav not available, TTS disabled.")
481
+ return
482
+
483
+ while not self._stop.is_set():
484
+ try:
485
+ item = self.tts_queue.get(timeout=0.5)
486
+ except queue.Empty:
487
+ continue
488
+
489
+ tokens = item.get("tokens", [])
490
+ if not tokens:
491
+ continue
492
+
493
+ try:
494
+ import io
495
+ import soundfile as sf
496
+
497
+ wav_bytes = self._vocoder(tokens, None)
498
+ waveform, sr = sf.read(io.BytesIO(wav_bytes))
499
+ sd.play(waveform, sr, blocking=False)
500
+ except Exception as e:
501
+ print(f"TTS playback error: {e}")
502
+
503
+
504
+ def run_live_mode(model, processor, args):
505
+ """Run full duplex streaming mode.
506
+
507
+ Args:
508
+ model: loaded MiniCPM-o model
509
+ processor: tokenizer/processor
510
+ args: argparse namespace with capture_region, audio_device, tts options
511
+ """
512
+ from mlx_vlm.models.minicpmo.audio import StreamingMelProcessor
513
+
514
+ print("Starting live streaming mode...")
515
+ print("Press Ctrl+C to stop.\n")
516
+
517
+ # Create queues
518
+ raw_queue = queue.Queue(maxsize=30)
519
+ sync_queue = queue.Queue(maxsize=10)
520
+ tts_queue = queue.Queue(maxsize=10) if args.tts else None
521
+
522
+ # Create mel processor
523
+ mel_processor = StreamingMelProcessor(sample_rate=16000)
524
+
525
+ # Parse capture region
526
+ region = None
527
+ if hasattr(args, "capture_region") and args.capture_region:
528
+ parts = args.capture_region.split(",")
529
+ if len(parts) == 4:
530
+ region = tuple(int(p) for p in parts)
531
+
532
+ # Create threads
533
+ screen = ScreenCapture(raw_queue, region=region, fps=1.0)
534
+ audio_dev = getattr(args, "audio_device", "BlackHole")
535
+ audio = AudioCapture(raw_queue, device=audio_dev, sample_rate=16000)
536
+ sync = ChunkSynchronizer(raw_queue, sync_queue, mel_processor)
537
+
538
+ generator = DuplexGenerator(
539
+ model,
540
+ processor,
541
+ sync_queue,
542
+ tts_queue=tts_queue,
543
+ temperature=getattr(args, "temp", 0.0),
544
+ max_tokens_per_chunk=getattr(args, "max_tokens", 50),
545
+ enable_tts=getattr(args, "tts", False),
546
+ )
547
+
548
+ tts_playback = None
549
+ if tts_queue:
550
+ tts_playback = TTSPlayback(tts_queue)
551
+
552
+ # Set up callbacks
553
+ def on_text(text):
554
+ print(f"[{generator.chunk_count}] {text}")
555
+
556
+ def on_status(status):
557
+ print(
558
+ f" >> chunk={status['chunk']} mode={status['mode']} "
559
+ f"cache={status['cache_tokens']}tok "
560
+ f"latency={status['latency_ms']}ms "
561
+ f"mem={status['mem_gb']:.1f}GB",
562
+ flush=True,
563
+ )
564
+
565
+ generator.on_text = on_text
566
+ generator.on_status = on_status
567
+
568
+ # Start all threads
569
+ screen.start()
570
+ audio.start()
571
+ sync.start()
572
+ generator.start()
573
+ if tts_playback:
574
+ tts_playback.start()
575
+
576
+ print("Live mode active. Capturing screen + audio...\n")
577
+
578
+ try:
579
+ while True:
580
+ time.sleep(0.5)
581
+ except KeyboardInterrupt:
582
+ print("\nStopping live mode...")
583
+ finally:
584
+ screen.stop()
585
+ audio.stop()
586
+ sync.stop()
587
+ generator.stop()
588
+ if tts_playback:
589
+ tts_playback.stop()
590
+ print("Live mode stopped.")