dvalle08 commited on
Commit
ef8254e
·
1 Parent(s): d9055df

Enhance VAD configuration and add fallback participant ID support: Update VAD_MIN_SPEECH_DURATION and VAD_MIN_SILENCE_DURATION for faster response times. Introduce fallback participant ID logic in MetricsCollector to handle cases where participant metadata is absent, ensuring accurate trace reporting.

Browse files
.env.example CHANGED
@@ -61,6 +61,6 @@ LIVEKIT_PRE_CONNECT_AUDIO=true
61
  LIVEKIT_PRE_CONNECT_TIMEOUT=3.0
62
 
63
  # Voice Activity Detection (VAD) Configuration - OPTIMIZED FOR FALSE DETECTION FIX
64
- VAD_MIN_SPEECH_DURATION=0.25 # Require 250ms of speech before activation
65
- VAD_MIN_SILENCE_DURATION=0.5 # Require 500ms of silence before deactivation
66
  VAD_THRESHOLD=0.6 # Higher = less sensitive to noise (0.5 is default)
 
61
  LIVEKIT_PRE_CONNECT_TIMEOUT=3.0
62
 
63
  # Voice Activity Detection (VAD) Configuration - OPTIMIZED FOR FALSE DETECTION FIX
64
+ VAD_MIN_SPEECH_DURATION=0.18 # Require 180ms of speech before activation (faster turn pickup)
65
+ VAD_MIN_SILENCE_DURATION=0.30 # Require 300ms of silence before deactivation (faster turn end)
66
  VAD_THRESHOLD=0.6 # Higher = less sensitive to noise (0.5 is default)
src/agent/agent.py CHANGED
@@ -43,6 +43,13 @@ def _fallback_session_prefix() -> str | None:
43
  return None
44
 
45
 
 
 
 
 
 
 
 
46
  def setup_langfuse_tracer() -> TracerProvider | None:
47
  """Configure LiveKit telemetry tracer to export traces to Langfuse."""
48
  global _langfuse_tracer_provider
@@ -169,6 +176,7 @@ async def session_handler(ctx: agents.JobContext) -> None:
169
  room_id=initial_room_id,
170
  participant_id=initial_participant_id,
171
  fallback_session_prefix=_fallback_session_prefix(),
 
172
  langfuse_enabled=trace_provider is not None,
173
  )
174
 
 
43
  return None
44
 
45
 
46
+ def _fallback_participant_prefix() -> str | None:
47
+ """Use console-prefixed fallback participant id when running `... console`."""
48
+ if any(arg == "console" for arg in sys.argv[1:]):
49
+ return "console"
50
+ return None
51
+
52
+
53
  def setup_langfuse_tracer() -> TracerProvider | None:
54
  """Configure LiveKit telemetry tracer to export traces to Langfuse."""
55
  global _langfuse_tracer_provider
 
176
  room_id=initial_room_id,
177
  participant_id=initial_participant_id,
178
  fallback_session_prefix=_fallback_session_prefix(),
179
+ fallback_participant_prefix=_fallback_participant_prefix(),
180
  langfuse_enabled=trace_provider is not None,
181
  )
182
 
src/agent/metrics_collector.py CHANGED
@@ -180,6 +180,7 @@ class MetricsCollector:
180
  room_id: Optional[str] = None,
181
  participant_id: Optional[str] = None,
182
  fallback_session_prefix: Optional[str] = None,
 
183
  langfuse_enabled: bool = False,
184
  ) -> None:
185
  """Initialize metrics collector.
@@ -192,6 +193,8 @@ class MetricsCollector:
192
  participant_id: LiveKit participant identity when available
193
  fallback_session_prefix: Prefix used for generated fallback session id
194
  (e.g. "console" -> "console_<uuid>") when no metadata session id exists
 
 
195
  langfuse_enabled: Enable one-trace-per-turn Langfuse traces
196
  """
197
  self._room = room
@@ -208,7 +211,14 @@ class MetricsCollector:
208
  fallback_session_prefix
209
  )
210
  self._session_id = self._fallback_session_id or self.UNKNOWN_SESSION_ID
211
- self._participant_id = participant_id or self.UNKNOWN_PARTICIPANT_ID
 
 
 
 
 
 
 
212
  self._langfuse_enabled = langfuse_enabled
213
  self._pending_trace_turns: deque[TraceTurn] = deque()
214
  self._trace_lock = asyncio.Lock()
@@ -271,8 +281,10 @@ class MetricsCollector:
271
  ):
272
  turn.session_id = self._session_id
273
  if (
274
- turn.participant_id == self.UNKNOWN_PARTICIPANT_ID
275
- and self._participant_id != self.UNKNOWN_PARTICIPANT_ID
 
 
276
  ):
277
  turn.participant_id = self._participant_id
278
 
@@ -997,11 +1009,12 @@ class MetricsCollector:
997
  root_start_ns = time_ns()
998
  cursor_ns = root_start_ns
999
 
1000
- with tracer.start_as_current_span(
1001
  "turn",
1002
  context=root_context,
1003
  start_time=root_start_ns,
1004
- ) as turn_span:
 
1005
  turn.trace_id = trace.format_trace_id(turn_span.get_span_context().trace_id)
1006
  turn_span.set_attribute("session_id", turn.session_id)
1007
  turn_span.set_attribute("room_id", turn.room_id)
@@ -1127,7 +1140,7 @@ class MetricsCollector:
1127
  },
1128
  observation_output=str(conversational_latency_ms),
1129
  )
1130
-
1131
  self._close_span_at(turn_span, cursor_ns)
1132
  logger.info(
1133
  "Langfuse turn trace emitted: trace_id=%s turn_id=%s session_id=%s room_id=%s participant_id=%s",
@@ -1155,7 +1168,8 @@ class MetricsCollector:
1155
  actual_duration_ms = max(duration_ms, 0.0) if duration_ms is not None else None
1156
  end_ns = start_ns + self._duration_ms_to_ns(actual_duration_ms or 0.0)
1157
 
1158
- with tracer.start_as_current_span(name, context=context, start_time=start_ns) as span:
 
1159
  if actual_duration_ms is not None:
1160
  span.set_attribute("duration_ms", actual_duration_ms)
1161
  if observation_input is not None:
@@ -1168,6 +1182,7 @@ class MetricsCollector:
1168
  if value is None:
1169
  continue
1170
  span.set_attribute(key, value)
 
1171
  self._close_span_at(span, end_ns)
1172
  return end_ns
1173
 
@@ -1261,6 +1276,12 @@ class MetricsCollector:
1261
  return None
1262
  return f"{normalized_prefix}_{uuid.uuid4()}"
1263
 
 
 
 
 
 
 
1264
  async def _resolve_room_id(self) -> str:
1265
  if self._room_id and self._room_id != self._room_name:
1266
  return self._room_id
 
180
  room_id: Optional[str] = None,
181
  participant_id: Optional[str] = None,
182
  fallback_session_prefix: Optional[str] = None,
183
+ fallback_participant_prefix: Optional[str] = None,
184
  langfuse_enabled: bool = False,
185
  ) -> None:
186
  """Initialize metrics collector.
 
193
  participant_id: LiveKit participant identity when available
194
  fallback_session_prefix: Prefix used for generated fallback session id
195
  (e.g. "console" -> "console_<uuid>") when no metadata session id exists
196
+ fallback_participant_prefix: Prefix used for generated fallback participant id
197
+ (e.g. "console" -> "console_<uuid>") when no participant identity exists
198
  langfuse_enabled: Enable one-trace-per-turn Langfuse traces
199
  """
200
  self._room = room
 
211
  fallback_session_prefix
212
  )
213
  self._session_id = self._fallback_session_id or self.UNKNOWN_SESSION_ID
214
+ self._fallback_participant_id = self._build_fallback_participant_id(
215
+ fallback_participant_prefix
216
+ )
217
+ self._participant_id = (
218
+ self._normalize_optional_text(participant_id)
219
+ or self._fallback_participant_id
220
+ or self.UNKNOWN_PARTICIPANT_ID
221
+ )
222
  self._langfuse_enabled = langfuse_enabled
223
  self._pending_trace_turns: deque[TraceTurn] = deque()
224
  self._trace_lock = asyncio.Lock()
 
281
  ):
282
  turn.session_id = self._session_id
283
  if (
284
+ turn.participant_id
285
+ in {self.UNKNOWN_PARTICIPANT_ID, self._fallback_participant_id}
286
+ and self._participant_id
287
+ not in {self.UNKNOWN_PARTICIPANT_ID, self._fallback_participant_id}
288
  ):
289
  turn.participant_id = self._participant_id
290
 
 
1009
  root_start_ns = time_ns()
1010
  cursor_ns = root_start_ns
1011
 
1012
+ turn_span = tracer.start_span(
1013
  "turn",
1014
  context=root_context,
1015
  start_time=root_start_ns,
1016
+ )
1017
+ try:
1018
  turn.trace_id = trace.format_trace_id(turn_span.get_span_context().trace_id)
1019
  turn_span.set_attribute("session_id", turn.session_id)
1020
  turn_span.set_attribute("room_id", turn.room_id)
 
1140
  },
1141
  observation_output=str(conversational_latency_ms),
1142
  )
1143
+ finally:
1144
  self._close_span_at(turn_span, cursor_ns)
1145
  logger.info(
1146
  "Langfuse turn trace emitted: trace_id=%s turn_id=%s session_id=%s room_id=%s participant_id=%s",
 
1168
  actual_duration_ms = max(duration_ms, 0.0) if duration_ms is not None else None
1169
  end_ns = start_ns + self._duration_ms_to_ns(actual_duration_ms or 0.0)
1170
 
1171
+ span = tracer.start_span(name, context=context, start_time=start_ns)
1172
+ try:
1173
  if actual_duration_ms is not None:
1174
  span.set_attribute("duration_ms", actual_duration_ms)
1175
  if observation_input is not None:
 
1182
  if value is None:
1183
  continue
1184
  span.set_attribute(key, value)
1185
+ finally:
1186
  self._close_span_at(span, end_ns)
1187
  return end_ns
1188
 
 
1276
  return None
1277
  return f"{normalized_prefix}_{uuid.uuid4()}"
1278
 
1279
+ def _build_fallback_participant_id(self, prefix: Optional[str]) -> Optional[str]:
1280
+ normalized_prefix = self._normalize_optional_text(prefix)
1281
+ if not normalized_prefix:
1282
+ return None
1283
+ return f"{normalized_prefix}_{uuid.uuid4()}"
1284
+
1285
  async def _resolve_room_id(self) -> str:
1286
  if self._room_id and self._room_id != self._room_name:
1287
  return self._room_id
src/core/settings.py CHANGED
@@ -91,13 +91,13 @@ class VoiceSettings(CoreSettings):
91
 
92
  # Voice Activity Detection Settings
93
  VAD_MIN_SPEECH_DURATION: float = Field(
94
- default=0.25,
95
  ge=0.1,
96
  le=1.0,
97
  description="Minimum speech duration (seconds) before VAD activation",
98
  )
99
  VAD_MIN_SILENCE_DURATION: float = Field(
100
- default=0.5,
101
  ge=0.1,
102
  le=2.0,
103
  description="Minimum silence duration (seconds) before VAD deactivation",
 
91
 
92
  # Voice Activity Detection Settings
93
  VAD_MIN_SPEECH_DURATION: float = Field(
94
+ default=0.18,
95
  ge=0.1,
96
  le=1.0,
97
  description="Minimum speech duration (seconds) before VAD activation",
98
  )
99
  VAD_MIN_SILENCE_DURATION: float = Field(
100
+ default=0.30,
101
  ge=0.1,
102
  le=2.0,
103
  description="Minimum silence duration (seconds) before VAD deactivation",
src/plugins/pocket_tts/tts.py CHANGED
@@ -4,6 +4,7 @@ import asyncio
4
  import contextlib
5
  import logging
6
  import queue
 
7
  import time
8
  from collections.abc import AsyncIterator
9
  from dataclasses import dataclass
@@ -27,6 +28,12 @@ logging.getLogger("pocket_tts.conditioners.text").setLevel(logging.WARNING)
27
 
28
  DEFAULT_VOICE = "alba"
29
  NATIVE_SAMPLE_RATE = 24000
 
 
 
 
 
 
30
 
31
 
32
  class TTSMetricsCallback(Protocol):
@@ -218,6 +225,13 @@ class PocketTTS(tts.TTS):
218
  audio_duration = _bytes_to_duration(total_bytes=total_bytes, sample_rate=self.sample_rate)
219
  return first_chunk_ttfb, generation_duration, audio_duration
220
 
 
 
 
 
 
 
 
221
 
222
  class PocketChunkedStream(tts.ChunkedStream):
223
  def __init__(
@@ -240,19 +254,32 @@ class PocketChunkedStream(tts.ChunkedStream):
240
  stream=False,
241
  )
242
 
243
- (
244
- first_chunk_ttfb,
245
- generation_duration,
246
- audio_duration,
247
- ) = await pocket_tts._push_generated_audio(
248
- text=self._input_text,
249
- conn_options=self._conn_options,
250
- output_emitter=output_emitter,
251
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
  output_emitter.flush()
254
 
255
- if pocket_tts._metrics_callback:
256
  pocket_tts._metrics_callback(
257
  ttfb=first_chunk_ttfb,
258
  duration=generation_duration,
@@ -290,33 +317,130 @@ class PocketSynthesizeStream(tts.SynthesizeStream):
290
  async def _flush_text_buffer(
291
  self, *, text_buffer: str, output_emitter: tts.AudioEmitter
292
  ) -> None:
293
- if not text_buffer.strip():
 
 
294
  return
295
 
296
- segment_id = shortuuid("SEG_")
297
- output_emitter.start_segment(segment_id=segment_id)
298
- await self._synthesize_segment(text_buffer, output_emitter)
299
- output_emitter.end_segment()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
- async def _synthesize_segment(self, text: str, output_emitter: tts.AudioEmitter) -> None:
 
 
 
 
 
 
 
 
 
302
  self._mark_started()
303
  pocket_tts = cast(PocketTTS, self._tts)
304
- (
305
- first_chunk_ttfb,
306
- generation_duration,
307
- audio_duration,
308
- ) = await pocket_tts._push_generated_audio(
309
  text=text,
310
  conn_options=self._conn_options,
311
  output_emitter=output_emitter,
312
  )
313
 
314
- if pocket_tts._metrics_callback:
315
- pocket_tts._metrics_callback(
316
- ttfb=first_chunk_ttfb,
317
- duration=generation_duration,
318
- audio_duration=audio_duration,
319
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
 
322
  def _tensor_to_pcm_bytes(
 
4
  import contextlib
5
  import logging
6
  import queue
7
+ import re
8
  import time
9
  from collections.abc import AsyncIterator
10
  from dataclasses import dataclass
 
28
 
29
  DEFAULT_VOICE = "alba"
30
  NATIVE_SAMPLE_RATE = 24000
31
+ MAX_TTS_SEGMENT_CHARS = 220
32
+
33
+ _BULLET_PREFIX_RE = re.compile(r"^\s*(?:[-*+]|(?:\d+[\.\)]))\s+")
34
+ _MARKDOWN_LINK_RE = re.compile(r"\[([^\]]+)\]\((?:[^)]+)\)")
35
+ _SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
36
+ _WHITESPACE_RE = re.compile(r"\s+")
37
 
38
 
39
  class TTSMetricsCallback(Protocol):
 
225
  audio_duration = _bytes_to_duration(total_bytes=total_bytes, sample_rate=self.sample_rate)
226
  return first_chunk_ttfb, generation_duration, audio_duration
227
 
228
+ def _prepare_text_segments(self, text: str) -> list[str]:
229
+ """Normalize text for TTS and split into short chunks for lower tail latency."""
230
+ cleaned = _sanitize_tts_text(text)
231
+ if not cleaned:
232
+ return []
233
+ return _chunk_tts_text(cleaned, max_chars=MAX_TTS_SEGMENT_CHARS)
234
+
235
 
236
  class PocketChunkedStream(tts.ChunkedStream):
237
  def __init__(
 
254
  stream=False,
255
  )
256
 
257
+ text_segments = pocket_tts._prepare_text_segments(self._input_text)
258
+ if not text_segments:
259
+ output_emitter.flush()
260
+ return
261
+
262
+ first_chunk_ttfb = -1.0
263
+ generation_duration = 0.0
264
+ audio_duration = 0.0
265
+ for text_segment in text_segments:
266
+ (
267
+ segment_ttfb,
268
+ segment_duration,
269
+ segment_audio_duration,
270
+ ) = await pocket_tts._push_generated_audio(
271
+ text=text_segment,
272
+ conn_options=self._conn_options,
273
+ output_emitter=output_emitter,
274
+ )
275
+ if first_chunk_ttfb < 0 and segment_ttfb >= 0:
276
+ first_chunk_ttfb = segment_ttfb
277
+ generation_duration += segment_duration
278
+ audio_duration += segment_audio_duration
279
 
280
  output_emitter.flush()
281
 
282
+ if pocket_tts._metrics_callback and first_chunk_ttfb >= 0:
283
  pocket_tts._metrics_callback(
284
  ttfb=first_chunk_ttfb,
285
  duration=generation_duration,
 
317
  async def _flush_text_buffer(
318
  self, *, text_buffer: str, output_emitter: tts.AudioEmitter
319
  ) -> None:
320
+ pocket_tts = cast(PocketTTS, self._tts)
321
+ text_segments = pocket_tts._prepare_text_segments(text_buffer)
322
+ if not text_segments:
323
  return
324
 
325
+ # LiveKit expects one segment per flushed text buffer in streaming mode.
326
+ output_emitter.start_segment(segment_id=shortuuid("SEG_"))
327
+ first_chunk_ttfb = -1.0
328
+ generation_duration = 0.0
329
+ audio_duration = 0.0
330
+ try:
331
+ for text_segment in text_segments:
332
+ (
333
+ segment_ttfb,
334
+ segment_duration,
335
+ segment_audio_duration,
336
+ ) = await self._synthesize_segment(text_segment, output_emitter)
337
+ if first_chunk_ttfb < 0 and segment_ttfb >= 0:
338
+ first_chunk_ttfb = segment_ttfb
339
+ generation_duration += segment_duration
340
+ audio_duration += segment_audio_duration
341
+ finally:
342
+ output_emitter.end_segment()
343
 
344
+ if pocket_tts._metrics_callback and first_chunk_ttfb >= 0:
345
+ pocket_tts._metrics_callback(
346
+ ttfb=first_chunk_ttfb,
347
+ duration=generation_duration,
348
+ audio_duration=audio_duration,
349
+ )
350
+
351
+ async def _synthesize_segment(
352
+ self, text: str, output_emitter: tts.AudioEmitter
353
+ ) -> tuple[float, float, float]:
354
  self._mark_started()
355
  pocket_tts = cast(PocketTTS, self._tts)
356
+ return await pocket_tts._push_generated_audio(
 
 
 
 
357
  text=text,
358
  conn_options=self._conn_options,
359
  output_emitter=output_emitter,
360
  )
361
 
362
+
363
def _sanitize_tts_text(text: str) -> str:
    """Strip markdown noise (links, bullets, headers, emphasis, code ticks, table
    pipes) and collapse whitespace so the TTS engine speaks clean prose."""
    if not text:
        return ""

    # Unify line endings, then reduce markdown links to their visible label.
    unified = text.replace("\r\n", "\n").replace("\r", "\n")
    unified = _MARKDOWN_LINK_RE.sub(r"\1", unified)

    kept_lines: list[str] = []
    for candidate in unified.split("\n"):
        stripped = candidate.strip()
        if not stripped:
            continue
        stripped = _BULLET_PREFIX_RE.sub("", stripped)
        # Drop leading header/quote markers, then remove inline markdown tokens.
        stripped = stripped.lstrip("#> ").strip()
        for token, replacement in (
            ("**", ""),
            ("__", ""),
            ("`", ""),
            ("*", ""),
            ("|", " "),
        ):
            stripped = stripped.replace(token, replacement)
        kept_lines.append(stripped)

    joined = " ".join(kept_lines)
    return _WHITESPACE_RE.sub(" ", joined).strip()
387
+
388
+
389
def _chunk_tts_text(text: str, *, max_chars: int) -> list[str]:
    """Split *text* into chunks of at most *max_chars*, preferring sentence
    boundaries and greedily packing sentences together while they fit."""
    if not text.strip():
        return []
    if len(text) <= max_chars:
        return [text]

    pieces = [part.strip() for part in _SENTENCE_SPLIT_RE.split(text) if part.strip()]
    if not pieces:
        # No sentence boundary found; treat the whole text as one piece.
        pieces = [text.strip()]

    result: list[str] = []
    buffer = ""
    for piece in pieces:
        # A single sentence longer than the budget is word-split first.
        for fragment in _split_overlong_text(piece, max_chars=max_chars):
            if not buffer:
                buffer = fragment
            elif len(buffer) + 1 + len(fragment) <= max_chars:
                # +1 accounts for the joining space.
                buffer = f"{buffer} {fragment}"
            else:
                result.append(buffer)
                buffer = fragment

    if buffer:
        result.append(buffer)
    return result
416
+
417
+
418
+ def _split_overlong_text(text: str, *, max_chars: int) -> list[str]:
419
+ if len(text) <= max_chars:
420
+ return [text]
421
+
422
+ words = text.split()
423
+ if not words:
424
+ return []
425
+
426
+ chunks: list[str] = []
427
+ current_words: list[str] = []
428
+ current_len = 0
429
+ for word in words:
430
+ additional_len = len(word) if not current_words else len(word) + 1
431
+ if current_words and current_len + additional_len > max_chars:
432
+ chunks.append(" ".join(current_words))
433
+ current_words = [word]
434
+ current_len = len(word)
435
+ continue
436
+
437
+ current_words.append(word)
438
+ current_len += additional_len
439
+
440
+ if current_words:
441
+ chunks.append(" ".join(current_words))
442
+
443
+ return chunks
444
 
445
 
446
  def _tensor_to_pcm_bytes(
tests/test_langfuse_turn_tracing.py CHANGED
@@ -8,6 +8,7 @@ from typing import Any
8
 
9
  import pytest
10
 
 
11
  from livekit.agents import metrics
12
 
13
  from src.agent.metrics_collector import MetricsCollector
@@ -23,6 +24,7 @@ class _FakeSpan:
23
  name: str
24
  trace_id: int
25
  attributes: dict[str, Any] = field(default_factory=dict)
 
26
 
27
  def set_attribute(self, key: str, value: Any) -> None:
28
  self.attributes[key] = value
@@ -30,6 +32,10 @@ class _FakeSpan:
30
  def get_span_context(self) -> _FakeSpanContext:
31
  return _FakeSpanContext(trace_id=self.trace_id)
32
 
 
 
 
 
33
 
34
  class _FakeTracer:
35
  def __init__(self) -> None:
@@ -37,6 +43,22 @@ class _FakeTracer:
37
  self._stack: list[_FakeSpan] = []
38
  self._next_trace_id = 1
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  @contextmanager
41
  def start_as_current_span(self, name: str, **_: Any): # type: ignore[no-untyped-def]
42
  if self._stack:
@@ -55,6 +77,9 @@ class _FakeTracer:
55
 
56
 
57
  class _BrokenTracer:
 
 
 
58
  @contextmanager
59
  def start_as_current_span(self, name: str, **_: Any): # type: ignore[no-untyped-def]
60
  raise RuntimeError(f"broken tracer for {name}")
@@ -255,6 +280,7 @@ def test_turn_trace_has_required_metadata_and_spans(monkeypatch: pytest.MonkeyPa
255
  assert conversational_span.attributes["stt_finalization_ms"] == pytest.approx(250.0)
256
  assert conversational_span.attributes["llm_ttft_ms"] > 0
257
  assert conversational_span.attributes["tts_ttfb_ms"] > 0
 
258
 
259
  payloads = _decode_payloads(room)
260
  trace_updates = [payload for payload in payloads if payload.get("type") == "trace_update"]
@@ -566,3 +592,78 @@ def test_real_session_metadata_overrides_fallback_for_pending_turns(
566
  turn_spans = [span for span in fake_tracer.spans if span.name == "turn"]
567
  assert len(turn_spans) == 1
568
  assert turn_spans[0].attributes["session_id"] == "session-real"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  import pytest
10
 
11
+ from opentelemetry import trace as otel_trace
12
  from livekit.agents import metrics
13
 
14
  from src.agent.metrics_collector import MetricsCollector
 
24
  name: str
25
  trace_id: int
26
  attributes: dict[str, Any] = field(default_factory=dict)
27
+ end_count: int = 0
28
 
29
  def set_attribute(self, key: str, value: Any) -> None:
30
  self.attributes[key] = value
 
32
  def get_span_context(self) -> _FakeSpanContext:
33
  return _FakeSpanContext(trace_id=self.trace_id)
34
 
35
+ def end(self, end_time: Any = None) -> None:
36
+ _ = end_time
37
+ self.end_count += 1
38
+
39
 
40
  class _FakeTracer:
41
  def __init__(self) -> None:
 
43
  self._stack: list[_FakeSpan] = []
44
  self._next_trace_id = 1
45
 
46
+ def start_span(self, name: str, **kwargs: Any) -> _FakeSpan:
47
+ trace_id = None
48
+ context = kwargs.get("context")
49
+ if context is not None:
50
+ parent_span = otel_trace.get_current_span(context)
51
+ get_span_context = getattr(parent_span, "get_span_context", None)
52
+ if callable(get_span_context):
53
+ trace_id = get_span_context().trace_id
54
+ if not trace_id:
55
+ trace_id = self._next_trace_id
56
+ self._next_trace_id += 1
57
+
58
+ span = _FakeSpan(name=name, trace_id=trace_id)
59
+ self.spans.append(span)
60
+ return span
61
+
62
  @contextmanager
63
  def start_as_current_span(self, name: str, **_: Any): # type: ignore[no-untyped-def]
64
  if self._stack:
 
77
 
78
 
79
  class _BrokenTracer:
80
+ def start_span(self, name: str, **_: Any) -> Any:
81
+ raise RuntimeError(f"broken tracer for {name}")
82
+
83
  @contextmanager
84
  def start_as_current_span(self, name: str, **_: Any): # type: ignore[no-untyped-def]
85
  raise RuntimeError(f"broken tracer for {name}")
 
280
  assert conversational_span.attributes["stt_finalization_ms"] == pytest.approx(250.0)
281
  assert conversational_span.attributes["llm_ttft_ms"] > 0
282
  assert conversational_span.attributes["tts_ttfb_ms"] > 0
283
+ assert all(span.end_count == 1 for span in fake_tracer.spans)
284
 
285
  payloads = _decode_payloads(room)
286
  trace_updates = [payload for payload in payloads if payload.get("type") == "trace_update"]
 
592
  turn_spans = [span for span in fake_tracer.spans if span.name == "turn"]
593
  assert len(turn_spans) == 1
594
  assert turn_spans[0].attributes["session_id"] == "session-real"
595
+
596
+
597
+ def test_fallback_console_participant_id_is_used_when_metadata_absent(
598
+ monkeypatch: pytest.MonkeyPatch,
599
+ ) -> None:
600
+ import src.agent.metrics_collector as metrics_collector_module
601
+
602
+ fake_tracer = _FakeTracer()
603
+ monkeypatch.setattr(metrics_collector_module, "tracer", fake_tracer)
604
+
605
+ room = _FakeRoom()
606
+ collector = MetricsCollector(
607
+ room=room, # type: ignore[arg-type]
608
+ model_name="moonshine",
609
+ room_name=room.name,
610
+ room_id="RM123",
611
+ participant_id=None,
612
+ fallback_session_prefix="console",
613
+ fallback_participant_prefix="console",
614
+ langfuse_enabled=True,
615
+ )
616
+
617
+ async def _run() -> None:
618
+ await collector.on_user_input_transcribed("console participant", is_final=True)
619
+ await collector.on_metrics_collected(_make_llm_metrics("speech-console-participant"))
620
+ await collector.on_conversation_item_added(role="assistant", content="ok")
621
+ await collector.on_metrics_collected(_make_tts_metrics("speech-console-participant"))
622
+ await collector.wait_for_pending_trace_tasks()
623
+
624
+ asyncio.run(_run())
625
+
626
+ turn_spans = [span for span in fake_tracer.spans if span.name == "turn"]
627
+ assert len(turn_spans) == 1
628
+ participant_id = turn_spans[0].attributes["participant_id"]
629
+ assert participant_id.startswith("console_")
630
+ assert participant_id != "unknown-participant"
631
+
632
+
633
+ def test_real_participant_metadata_overrides_fallback_for_pending_turns(
634
+ monkeypatch: pytest.MonkeyPatch,
635
+ ) -> None:
636
+ import src.agent.metrics_collector as metrics_collector_module
637
+
638
+ fake_tracer = _FakeTracer()
639
+ monkeypatch.setattr(metrics_collector_module, "tracer", fake_tracer)
640
+
641
+ room = _FakeRoom()
642
+ collector = MetricsCollector(
643
+ room=room, # type: ignore[arg-type]
644
+ model_name="moonshine",
645
+ room_name=room.name,
646
+ room_id="RM123",
647
+ participant_id=None,
648
+ fallback_session_prefix="console",
649
+ fallback_participant_prefix="console",
650
+ langfuse_enabled=True,
651
+ )
652
+
653
+ async def _run() -> None:
654
+ await collector.on_user_input_transcribed("override participant", is_final=True)
655
+ await collector.on_metrics_collected(_make_llm_metrics("speech-override-participant"))
656
+ await collector.on_metrics_collected(_make_tts_metrics("speech-override-participant"))
657
+ await collector.on_session_metadata(
658
+ session_id="session-real-participant",
659
+ participant_id="web-real-participant",
660
+ )
661
+ await collector.on_conversation_item_added(role="assistant", content="reply")
662
+ await collector.wait_for_pending_trace_tasks()
663
+
664
+ asyncio.run(_run())
665
+
666
+ turn_spans = [span for span in fake_tracer.spans if span.name == "turn"]
667
+ assert len(turn_spans) == 1
668
+ assert turn_spans[0].attributes["session_id"] == "session-real-participant"
669
+ assert turn_spans[0].attributes["participant_id"] == "web-real-participant"
tests/test_pocket_tts_plugin.py CHANGED
@@ -27,6 +27,7 @@ def pocket_plugin(monkeypatch: pytest.MonkeyPatch) -> Any:
27
  "raise_on_generate": None,
28
  "active_generations": 0,
29
  "max_active_generations": 0,
 
30
  }
31
 
32
  class _FakeModel:
@@ -46,6 +47,7 @@ def pocket_plugin(monkeypatch: pytest.MonkeyPatch) -> Any:
46
  ) -> Generator[np.ndarray[Any, np.dtype[np.float32]], None, None]:
47
  calls["state"] = state
48
  calls["text"] = text
 
49
  calls["copy_state"] = copy_state
50
  calls["active_generations"] += 1
51
  calls["max_active_generations"] = max(
@@ -203,6 +205,30 @@ def test_stream_emits_before_generation_completes(pocket_plugin: Any) -> None:
203
  asyncio.run(_run())
204
 
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  def test_chunked_generation_serializes_concurrent_requests(pocket_plugin: Any) -> None:
207
  module = pocket_plugin["module"]
208
  pocket_plugin["per_chunk_sleep"] = 0.03
@@ -253,5 +279,55 @@ def test_generation_timeout_is_mapped_to_api_timeout_error(pocket_plugin: Any) -
253
  gate.set()
254
 
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  async def _collect_events(stream: Any) -> list[Any]:
257
  return [event async for event in stream]
 
27
  "raise_on_generate": None,
28
  "active_generations": 0,
29
  "max_active_generations": 0,
30
+ "texts": [],
31
  }
32
 
33
  class _FakeModel:
 
47
  ) -> Generator[np.ndarray[Any, np.dtype[np.float32]], None, None]:
48
  calls["state"] = state
49
  calls["text"] = text
50
+ calls["texts"].append(text)
51
  calls["copy_state"] = copy_state
52
  calls["active_generations"] += 1
53
  calls["max_active_generations"] = max(
 
205
  asyncio.run(_run())
206
 
207
 
208
+ def test_stream_uses_single_segment_for_one_flush(pocket_plugin: Any) -> None:
209
+ module = pocket_plugin["module"]
210
+ tts_v = module.PocketTTS(voice="alba")
211
+ long_text = (
212
+ "First sentence with enough words to trigger internal chunking. " * 6
213
+ + "Second sentence also long enough to split. " * 6
214
+ )
215
+
216
+ async def _run() -> None:
217
+ async with tts_v.stream() as synth_stream:
218
+ synth_stream.push_text(long_text)
219
+ synth_stream.end_input()
220
+ events = await asyncio.wait_for(_collect_events(synth_stream), timeout=3.0)
221
+
222
+ segment_ids = {
223
+ event.segment_id
224
+ for event in events
225
+ if not event.is_final and isinstance(event.segment_id, str) and event.segment_id
226
+ }
227
+ assert len(segment_ids) == 1
228
+
229
+ asyncio.run(_run())
230
+
231
+
232
  def test_chunked_generation_serializes_concurrent_requests(pocket_plugin: Any) -> None:
233
  module = pocket_plugin["module"]
234
  pocket_plugin["per_chunk_sleep"] = 0.03
 
279
  gate.set()
280
 
281
 
282
+ def test_sanitize_tts_text_removes_markdown_noise(pocket_plugin: Any) -> None:
283
+ module = pocket_plugin["module"]
284
+ raw_text = """
285
+ ## Title
286
+ - **Bold** item with [link text](https://example.com)
287
+ 1. `code` item
288
+ """
289
+
290
+ sanitized = module._sanitize_tts_text(raw_text)
291
+ assert "##" not in sanitized
292
+ assert "**" not in sanitized
293
+ assert "`" not in sanitized
294
+ assert "[link text]" not in sanitized
295
+ assert "(https://example.com)" not in sanitized
296
+ assert "link text" in sanitized
297
+ assert "Bold item with" in sanitized
298
+
299
+
300
+ def test_chunk_tts_text_respects_length_limit(pocket_plugin: Any) -> None:
301
+ module = pocket_plugin["module"]
302
+ text = " ".join(["word"] * 80)
303
+
304
+ chunks = module._chunk_tts_text(text, max_chars=40)
305
+ assert len(chunks) > 1
306
+ assert all(len(chunk) <= 40 for chunk in chunks)
307
+ assert " ".join(chunks).replace(" ", " ").strip() == text
308
+
309
+
310
+ def test_chunked_synthesize_sanitizes_and_splits_long_text(pocket_plugin: Any) -> None:
311
+ module = pocket_plugin["module"]
312
+ tts_v = module.PocketTTS(voice="alba")
313
+ text = (
314
+ "## Header\n"
315
+ "- **First** item with [a link](https://example.com).\n"
316
+ + "Second sentence keeps going with enough words to exceed the segment limit. " * 5
317
+ + "Third sentence keeps going with enough words to exceed the segment limit. " * 5
318
+ )
319
+
320
+ async def _run() -> None:
321
+ await _collect_events(tts_v.synthesize(text))
322
+
323
+ asyncio.run(_run())
324
+
325
+ generated_texts = pocket_plugin["texts"]
326
+ assert len(generated_texts) >= 2
327
+ assert all(len(part) <= module.MAX_TTS_SEGMENT_CHARS for part in generated_texts)
328
+ assert all("**" not in part and "`" not in part for part in generated_texts)
329
+ assert all("https://example.com" not in part for part in generated_texts)
330
+
331
+
332
  async def _collect_events(stream: Any) -> list[Any]:
333
  return [event async for event in stream]