# voice-agent / app/speech.py
# Last commit 5f0a2ac by RalphThings: "Deploy Hugging Face Space"
from __future__ import annotations
import asyncio
import gc
import json
import importlib
import os
import re
import shutil
import subprocess
import sys
import tempfile
from collections.abc import AsyncIterator
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
import httpx
import numpy as np
from app.audio import wav_bytes_from_float32
from app.config import settings
# Inline non-verbal cues such as "[laugh]" or "[Sigh]" (matched case-insensitively).
PARALINGUISTIC_TAG_PATTERN = re.compile(r"\[(laugh|chuckle|sigh|cough)\]", flags=re.I)
# ANSI escape sequences: two-character escapes and CSI sequences such as "\x1b[31m".
ANSI_ESCAPE_PATTERN = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
# A whole <think>...</think> block; non-greedy, and re.S lets it span newlines.
THINK_BLOCK_PATTERN = re.compile(r"<think>.*?</think>", flags=re.I | re.S)
# A lone opening or closing think tag.
THINK_TAG_PATTERN = re.compile(r"</?think>", flags=re.I)
# A speaker label ("User:", "assistant :", ...) plus the whitespace after it.
ROLE_CONTINUATION_PATTERN = re.compile(r"(?i)\b(?:human|user|assistant|system)\s*:\s*")
# Meta-commentary tails: each alternative matches from its trigger phrase to the end
# of the text (the leading inline "(?is)" flags apply to the whole alternation).
META_TAIL_PATTERN = re.compile(
    "|".join(
        (
            r"(?is)\b(?:thought|analysis|reasoning|explanation|note)\s*:\s*.*$",
            r"\bthis response\b.*$",
            r"\banswering in a conversational tone\b.*$",
            r"\byou can add more\b.*$",
            r"\bwithout any unnecessary elaboration\b.*$",
        )
    )
)
CLI_PROMPT_MARKER = "›"
# Phrases that identify injected memory context (each matched case-insensitively).
INTERNAL_REPLY_PATTERNS = [
    re.compile(phrase, flags=re.I)
    for phrase in (
        r"💭\s*Injected relevant context from memory\.?",
        r"Injected relevant context from memory\.?",
        r"Relevant context from memory\.?",
        r"Context from memory\.?",
    )
]

# Chatterbox ONNX: sample rate (Hz) used for the voice prompt, special speech-token
# ids, and the key/value cache geometry used to build the empty initial cache.
CHATTERBOX_ONNX_SAMPLE_RATE = 24000
CHATTERBOX_ONNX_START_SPEECH_TOKEN = 6561
CHATTERBOX_ONNX_STOP_SPEECH_TOKEN = 6562
CHATTERBOX_ONNX_SILENCE_TOKEN = 4299
CHATTERBOX_ONNX_NUM_KV_HEADS = 16
CHATTERBOX_ONNX_HEAD_DIM = 64
KOKORO_SAMPLE_RATE = 24000
@dataclass(eq=True, frozen=True)
class TranscriptionResult:
    """Immutable outcome of one speech-to-text pass."""

    # Recognized text; empty for empty audio, or a bracketed diagnostic such as
    # "[whisper unavailable] ..." when the model could not be loaded.
    text: str
    # Language code reported by the backend, when it reports one.
    language: str | None = None
    # Backend's confidence in ``language``, when reported.
    language_probability: float | None = None
    # Which engine produced the result (e.g. "whisper", "parakeet", "none").
    backend: str = ""
class RepetitionPenaltyLogitsProcessor:
    """Penalize logits of tokens that already occur in the generated sequence.

    For each token id present in ``input_ids``, a negative score is multiplied by
    ``penalty`` and a non-negative score is divided by it, so ``penalty > 1`` makes
    repeats less likely in both cases. The input ``scores`` array is not modified.
    """

    def __init__(self, penalty: float):
        if penalty <= 0:
            raise ValueError("penalty must be > 0")
        self.penalty = float(penalty)

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Return a penalized copy of ``scores``.

        Args:
            input_ids: Integer token ids per row, gathered along axis 1 of ``scores``.
            scores: Logits, one row per sequence.
        """
        factor = self.penalty
        previous = np.take_along_axis(scores, input_ids, axis=1)
        adjusted = np.where(previous < 0, previous * factor, previous / factor)
        penalized = scores.copy()
        np.put_along_axis(penalized, input_ids, adjusted, axis=1)
        return penalized
class ChatterboxOnnxTTS:
    """Chatterbox text-to-speech running exported ONNX graphs through ONNX Runtime.

    Four graphs are fetched from the Hugging Face Hub repo ``self._model_id``: a speech
    encoder for the reference voice clip, a token-embedding graph, an autoregressive
    language model that emits speech tokens, and a conditional decoder that renders
    speech tokens into a waveform.
    """

    def __init__(self) -> None:
        self._model_id = settings.chatterbox_onnx_model_id
        self._dtype = settings.chatterbox_onnx_dtype
        self._provider = settings.chatterbox_onnx_provider
        # Must follow the _provider assignment: it validates CUDA when "cuda" is requested.
        self._ort = self._load_onnxruntime()
        self._AutoTokenizer = importlib.import_module("transformers").AutoTokenizer
        self._hf_hub_download = importlib.import_module("huggingface_hub").hf_hub_download
        self._librosa = importlib.import_module("librosa")
        self._soundfile = importlib.import_module("soundfile")
        self._tokenizer = self._AutoTokenizer.from_pretrained(self._model_id)
        self._repetition_penalty = RepetitionPenaltyLogitsProcessor(settings.chatterbox_onnx_repetition_penalty)
        self._sessions = self._load_sessions()

    def _load_onnxruntime(self):
        """Import ONNX Runtime; for the "cuda" provider, verify and preload its CUDA libraries.

        If the normal import fails and ``settings.chatterbox_onnx_site_packages_path`` is
        set, that directory is prepended to ``sys.path`` and the import is retried.

        Raises:
            The import error itself when no extra path is configured (or the retry fails).
            RuntimeError: "cuda" was requested but CUDAExecutionProvider is unavailable.
        """
        extra_path = settings.chatterbox_onnx_site_packages_path.strip()
        # Import torch first so the working project CUDA stack is initialized before ONNX Runtime.
        try:
            import torch  # noqa: F401
        except Exception:
            pass  # torch is optional for this step.
        try:
            ort = importlib.import_module("onnxruntime")
        except Exception:
            if not extra_path:
                raise
            if extra_path not in sys.path:
                sys.path.insert(0, extra_path)
            ort = importlib.import_module("onnxruntime")
        if self._provider == "cuda":
            providers = ort.get_available_providers()
            if "CUDAExecutionProvider" not in providers:
                raise RuntimeError(
                    "onnxruntime does not expose CUDAExecutionProvider in the project environment; "
                    f"available providers: {providers}"
                )
            ort.preload_dlls()
        return ort

    def _filename_for(self, name: str) -> str:
        """Map a graph name to its ONNX file name for the configured dtype."""
        if self._dtype == "fp32":
            return f"{name}.onnx"
        if self._dtype == "q8":
            return f"{name}_quantized.onnx"
        return f"{name}_{self._dtype}.onnx"

    def _download_graph(self, name: str) -> str:
        """Download a graph plus its ``<file>_data`` external-weights sidecar; return the graph path."""
        filename = self._filename_for(name)
        graph = self._hf_hub_download(self._model_id, subfolder="onnx", filename=filename)
        # The sidecar only has to exist next to the graph; its path is not needed.
        self._hf_hub_download(self._model_id, subfolder="onnx", filename=f"{filename}_data")
        return graph

    def _make_session(self, path: str):
        """Create an InferenceSession pinned to the single configured execution provider."""
        providers = ["CUDAExecutionProvider"] if self._provider == "cuda" else ["CPUExecutionProvider"]
        return self._ort.InferenceSession(path, providers=providers)

    def _load_sessions(self) -> dict[str, object]:
        """Download and open all four graphs, keyed by graph name."""
        return {
            "speech_encoder": self._make_session(self._download_graph("speech_encoder")),
            "embed_tokens": self._make_session(self._download_graph("embed_tokens")),
            "language_model": self._make_session(self._download_graph("language_model")),
            "conditional_decoder": self._make_session(self._download_graph("conditional_decoder")),
        }

    def _resolve_voice_path(self, audio_prompt_path: str | None) -> str:
        """Pick the reference voice clip: the explicit path, else the configured default.

        Raises:
            ValueError: no path given and none configured.
            FileNotFoundError: the chosen path is not an existing file.
        """
        candidate = audio_prompt_path or settings.chatterbox_onnx_voice_path
        if not candidate:
            raise ValueError("No voice reference path configured for Chatterbox ONNX")
        if not Path(candidate).is_file():
            raise FileNotFoundError(f"Voice reference not found: {candidate}")
        return candidate

    def generate(self, text: str, audio_prompt_path: str | None = None) -> np.ndarray:
        """Synthesize ``text`` in the voice of a reference clip.

        Decoding is greedy (argmax) with a repetition penalty, for at most
        ``settings.chatterbox_onnx_max_new_tokens`` steps or until the stop token.

        Args:
            text: Text to speak.
            audio_prompt_path: Reference voice clip; defaults to the configured voice.

        Returns:
            The decoder's waveform as a float32 array, with the leading batch axis squeezed.
        """
        voice_path = self._resolve_voice_path(audio_prompt_path)
        audio_values, _ = self._librosa.load(voice_path, sr=CHATTERBOX_ONNX_SAMPLE_RATE)
        audio_values = audio_values[np.newaxis, :].astype(np.float32)
        input_ids = self._tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
        generate_tokens = np.array([[CHATTERBOX_ONNX_START_SPEECH_TOKEN]], dtype=np.int64)

        speech_encoder_session = self._sessions["speech_encoder"]
        embed_tokens_session = self._sessions["embed_tokens"]
        language_model_session = self._sessions["language_model"]
        cond_decoder_session = self._sessions["conditional_decoder"]

        # Encode the voice prompt up front (it was previously done inside the loop on
        # step 0), so the decoder inputs below exist even with a zero token budget.
        cond_emb, prompt_token, speaker_embeddings, speaker_features = speech_encoder_session.run(
            None, {"audio_values": audio_values}
        )
        # First step input: conditioning embedding followed by the text embeddings.
        inputs_embeds = embed_tokens_session.run(None, {"input_ids": input_ids})[0]
        inputs_embeds = np.concatenate((cond_emb, inputs_embeds), axis=1)
        batch_size, seq_len, _ = inputs_embeds.shape
        # Empty KV cache (zero-length sequence axis), dtype matched to each graph input.
        past_key_values = {
            item.name: np.zeros(
                [batch_size, CHATTERBOX_ONNX_NUM_KV_HEADS, 0, CHATTERBOX_ONNX_HEAD_DIM],
                dtype=np.float16 if item.type == "tensor(float16)" else np.float32,
            )
            for item in language_model_session.get_inputs()
            if "past_key_values" in item.name
        }
        attention_mask = np.ones((batch_size, seq_len), dtype=np.int64)
        position_ids = np.arange(seq_len, dtype=np.int64).reshape(1, -1).repeat(batch_size, axis=0)

        stopped = False
        for step in range(settings.chatterbox_onnx_max_new_tokens):
            if step > 0:
                # Later steps feed only the last sampled token; history lives in the KV cache.
                inputs_embeds = embed_tokens_session.run(None, {"input_ids": input_ids})[0]
            logits, *present_key_values = language_model_session.run(
                None,
                {
                    "inputs_embeds": inputs_embeds,
                    "attention_mask": attention_mask,
                    "position_ids": position_ids,
                    **past_key_values,
                },
            )
            next_token_logits = self._repetition_penalty(generate_tokens, logits[:, -1, :])
            input_ids = np.argmax(next_token_logits, axis=-1, keepdims=True).astype(np.int64)
            generate_tokens = np.concatenate((generate_tokens, input_ids), axis=-1)
            if (input_ids.flatten() == CHATTERBOX_ONNX_STOP_SPEECH_TOKEN).all():
                stopped = True
                break
            attention_mask = np.concatenate([attention_mask, np.ones((batch_size, 1), dtype=np.int64)], axis=1)
            position_ids = position_ids[:, -1:] + 1
            # Assumes present.* outputs are returned in the same order as the past_key_values.* inputs.
            past_key_values = dict(zip(past_key_values, present_key_values))

        # Drop the start token, and the final token only when it is the stop token: if the
        # budget ran out, the last token is real speech and must be kept.
        speech_tokens = generate_tokens[:, 1:-1] if stopped else generate_tokens[:, 1:]
        silence_tokens = np.full((speech_tokens.shape[0], 3), CHATTERBOX_ONNX_SILENCE_TOKEN, dtype=np.int64)
        speech_tokens = np.concatenate([prompt_token, speech_tokens, silence_tokens], axis=1)
        wav = cond_decoder_session.run(
            None,
            {
                "speech_tokens": speech_tokens,
                "speaker_embeddings": speaker_embeddings,
                "speaker_features": speaker_features,
            },
        )[0].squeeze(axis=0)
        return np.asarray(wav, dtype=np.float32)
class KokoroTTS:
    """Kokoro text-to-speech with the same ``generate(text, audio_prompt_path)`` API as the Chatterbox engines."""

    def __init__(self) -> None:
        kokoro_module = importlib.import_module("kokoro")
        # Resolve once. The probe prints fallback notices, and calling it twice (as before)
        # printed them twice and probed torch twice.
        device = self._resolve_device()
        self._pipeline = kokoro_module.KPipeline(
            lang_code=settings.kokoro_lang_code,
            repo_id=settings.kokoro_repo_id,
            device=device,
        )
        self._voice = settings.kokoro_voice
        self._speed = settings.kokoro_speed
        # Same attribute names as ChatterboxOnnxTTS; presumably read by runtime logging elsewhere.
        self._provider = device
        self._model_id = settings.kokoro_repo_id
        self._dtype = "fp32"

    def _resolve_device(self) -> str:
        """Return the torch device to use (CPU when CUDA was requested but is unavailable), logging fallbacks."""
        device, fallback_reason = SpeechPipeline._resolve_torch_device_static(
            settings.kokoro_device,
            component="Kokoro",
        )
        if fallback_reason:
            print(fallback_reason, file=sys.stderr)
        if device != settings.kokoro_device:
            print(
                f"Kokoro fallback: using device={device} instead of {settings.kokoro_device}",
                file=sys.stderr,
            )
        return device

    def generate(self, text: str, audio_prompt_path: str | None = None) -> np.ndarray:
        """Synthesize ``text`` with the configured voice and speed.

        ``audio_prompt_path`` is accepted for interface compatibility and ignored. Audio
        chunks from the pipeline are flattened and joined; the result is float32 and
        empty when the pipeline produced no audio.
        """
        del audio_prompt_path
        chunks: list[np.ndarray] = []
        for result in self._pipeline(text, voice=self._voice, speed=self._speed):
            audio = result.audio
            if audio is None:
                continue
            if hasattr(audio, "detach"):
                # Torch tensor: move off the device before converting to NumPy.
                audio = audio.detach().cpu().numpy()
            chunks.append(np.asarray(audio, dtype=np.float32).flatten())
        if not chunks:
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(chunks).astype(np.float32, copy=False)
class PersistentAgentCliSession:
    """Drive the ``my-agent`` CLI, one subprocess per prompt, streaming normalized events.

    Prompts are chained into one CLI conversation by passing the most recently reported
    session id back via ``--session``. A per-instance lock keeps at most one CLI run in
    flight at a time.
    """

    # Buffer limit (bytes) for the subprocess stream readers. asyncio's 64 KiB default
    # makes readline() raise on a longer stream-json line (e.g. a big tool result); the
    # reader then stops draining that pipe and the child can block writing to it while
    # we wait for it to exit.
    _STREAM_LIMIT = 16 * 1024 * 1024

    def __init__(self, model: str | None) -> None:
        self._model = model
        self._request_lock = asyncio.Lock()
        self._last_session_id: str | None = None

    async def stream_events(self, transcript: str, session_id: str | None = None) -> AsyncIterator[dict]:
        """Send ``transcript`` to the CLI and yield normalized event dicts as output arrives.

        Args:
            transcript: User utterance; whitespace-only input yields nothing.
            session_id: CLI session to continue; defaults to the last one reported.

        Yields:
            Dicts keyed by "kind" ("text", "status", "error", "session_created",
            "tool_use", "tool_result"). A synthetic "error" event is appended when the
            CLI exits non-zero without having reported an error itself.
        """
        cleaned = transcript.strip()
        if not cleaned:
            return
        prompt_text = self._build_voice_prompt(cleaned)
        async with self._request_lock:
            command = [self._resolve_my_agent_command()]
            if settings.my_agent_force:
                command.append("--force")
            if settings.my_agent_cwd:
                command.extend(["--cwd", settings.my_agent_cwd])
            if self._model:
                command.extend(["--model", self._model])
            command.extend(["--stream-json", "--prompt", prompt_text])
            active_session_id = session_id or self._last_session_id
            if active_session_id:
                command.extend(["--session", active_session_id])
            process = await asyncio.create_subprocess_exec(
                *command,
                cwd=settings.my_agent_cwd or None,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                limit=self._STREAM_LIMIT,
            )
            # Both pipes feed one queue; a None line marks the end of that stream.
            queue: asyncio.Queue[tuple[str, str | None]] = asyncio.Queue()
            readers = [
                asyncio.create_task(self._read_stream(process.stdout, "stdout", queue)),
                asyncio.create_task(self._read_stream(process.stderr, "stderr", queue)),
            ]
            saw_error = False
            try:
                open_streams = len(readers)
                while open_streams > 0:
                    source, line = await queue.get()
                    if line is None:
                        open_streams -= 1
                        continue
                    event = self._parse_stream_line(source, line, prompt_text)
                    if event is None:
                        continue
                    if event.get("kind") == "session_created":
                        self._last_session_id = (
                            event.get("newSessionId")
                            or event.get("session_id")
                            or self._last_session_id
                        )
                    if event.get("kind") == "error":
                        saw_error = True
                    yield event
                return_code = await process.wait()
                if return_code != 0 and not saw_error:
                    yield {"kind": "error", "content": f"my-agent prompt run exited with code {return_code}"}
            finally:
                # Runs on normal completion and when the consumer stops iterating early.
                for reader in readers:
                    reader.cancel()
                if process.returncode is None:
                    try:
                        process.terminate()
                    except ProcessLookupError:
                        pass  # Exited between the returncode check and terminate().
                    await process.wait()
                # Retrieve the readers' outcomes so a failure does not surface later as
                # "Task exception was never retrieved".
                await asyncio.gather(*readers, return_exceptions=True)

    def _resolve_my_agent_command(self) -> str:
        """Return the first existing executable among: the configured command,
        ``~/.cargo/bin/my-agent`` and ``my-agent`` on PATH.

        If none exists, return the configured command, or the cargo path when nothing is
        configured, so the later spawn fails with a meaningful path.
        """
        configured = (settings.my_agent_command or "").strip()
        # Default cargo location for the *current* user (was hard-coded to one home dir).
        cargo_install = os.path.expanduser("~/.cargo/bin/my-agent")
        candidates: list[str] = []
        for candidate in (configured, cargo_install, shutil.which("my-agent")):
            if candidate and candidate not in candidates:
                candidates.append(candidate)
        for candidate in candidates:
            expanded = Path(candidate).expanduser()
            if expanded.is_file() and os.access(expanded, os.X_OK):
                return str(expanded)
        return configured or cargo_install

    def reset_session(self) -> None:
        """Forget the remembered session id so the next prompt starts a new CLI session."""
        self._last_session_id = None

    async def _read_stream(
        self,
        stream: asyncio.StreamReader | None,
        source: str,
        queue: asyncio.Queue[tuple[str, str | None]],
    ) -> None:
        """Forward decoded lines from ``stream`` as ``(source, line)``; always end with ``(source, None)``."""
        if stream is None:
            await queue.put((source, None))
            return
        try:
            while True:
                raw_line = await stream.readline()
                if not raw_line:
                    break
                await queue.put((source, raw_line.decode("utf-8", errors="replace").rstrip()))
        finally:
            # Also sent on cancellation/error so the consumer's open-stream count reaches zero.
            await queue.put((source, None))

    def _parse_stream_line(self, source: str, line: str, transcript: str) -> dict | None:
        """Normalize one output line into an event dict, or None to drop it."""
        clean = self._strip_cli_formatting(line)
        if not clean:
            return None
        if source == "stderr":
            return self._normalize_stderr_line(clean)
        return self._normalize_json_line(clean, transcript)

    def _normalize_json_line(self, line: str, transcript: str) -> dict | None:
        """Map a stream-json record to an event; non-JSON lines are treated as plain text."""
        try:
            payload = json.loads(line)
        except json.JSONDecodeError:
            return self._normalize_text_line(line, transcript)
        if not isinstance(payload, dict):
            return None
        kind = str(payload.get("kind") or "").strip()
        role = str(payload.get("role") or "").strip().lower()
        if kind == "thinking":
            return {"kind": "status", "text": payload.get("content") or "thinking"}
        if kind in {"session_created", "status", "tool_use", "tool_result", "error"}:
            return payload
        if kind == "text":
            payload.setdefault("role", "assistant")
            return payload
        if role in {"assistant", "model"} and isinstance(payload.get("content"), str):
            return {"kind": "text", "role": "assistant", "content": payload["content"]}
        if kind in {"assistant_message", "message"} and isinstance(payload.get("content"), str):
            return {"kind": "text", "role": "assistant", "content": payload["content"]}
        return None

    def _normalize_text_line(self, line: str, transcript: str) -> dict | None:
        """Classify a plain-text line as a status, an error, or assistant text.

        A line equal to ``transcript`` (the echoed prompt) is dropped.
        """
        if not line or line == transcript:
            return None
        lower = line.lower()
        if lower in {
            "working on it",
            "working on it.",
            "thinking",
            "building a plan",
            "starting multi-step work",
            "complex task detected; switching to orchestrate mode",
            "simple task detected; using tools",
        }:
            return {"kind": "status", "text": line}
        if "error:" in lower or line.startswith("✗"):
            return {"kind": "error", "content": line}
        return {"kind": "text", "role": "assistant", "content": line}

    def _normalize_stderr_line(self, line: str) -> dict | None:
        """Treat warning/info-prefixed stderr lines as status; anything else as an error."""
        lower = line.lower()
        if (
            lower.startswith("warning:")
            or lower.startswith("warn ")
            or lower.startswith("vision:")
            or lower.startswith("info:")
        ):
            return {"kind": "status", "text": line}
        return {"kind": "error", "content": line}

    def _build_voice_prompt(self, transcript: str) -> str:
        """Prefix the configured voice preamble (if any) to the transcript."""
        preamble = settings.my_agent_voice_preamble.strip()
        if not preamble:
            return transcript
        return f"{preamble}\n\n{transcript}"

    def _strip_cli_formatting(self, text: str) -> str:
        """Remove ANSI escapes, bell/backspace characters, and collapse whitespace runs."""
        stripped = ANSI_ESCAPE_PATTERN.sub("", text)
        stripped = stripped.replace("\x07", "").replace("\x08", "")
        stripped = re.sub(r"\s+", " ", stripped)
        return stripped.strip()
class SpeechPipeline:
def __init__(self) -> None:
    # Most recent model-load failures, reported to callers instead of raising.
    self._whisper_error: str | None = None
    self._tts_error: str | None = None
    chat_model = self.my_agent_chat_model
    self._agent_cli = PersistentAgentCliSession(chat_model)
    # Held by the transcribe* methods around STT model access and inference.
    self._model_lock = asyncio.Lock()
    # When Whisper and TTS both target CUDA, they are swapped in and out of GPU memory.
    whisper_on_cuda = settings.whisper_device == "cuda"
    self._prefer_low_vram_gpu_swap = whisper_on_cuda and self._tts_uses_cuda()
    # Optionally keep TTS resident during swaps, so only Whisper gets evicted.
    self._keep_tts_gpu_resident = self._prefer_low_vram_gpu_swap and settings.tts_gpu_resident_preferred
    self._tts_runtime_logged = False
def preload_models(self) -> None:
    """Load STT, TTS and (for the hf-local backend) the local LLM eagerly, then warm TTS once.

    STT is Parakeet when configured and not overridden by
    ``_prefer_whisper_transcription()``, otherwise Whisper. The Whisper, Parakeet and TTS
    properties report load failures as None rather than raising. A failed TTS warm-up is
    logged to stderr instead of being silently ignored.
    """
    if settings.stt_backend == "parakeet-tdt-v3" and not self._prefer_whisper_transcription():
        _ = self.parakeet
    else:
        _ = self.whisper
    _ = self.tts
    if settings.assistant_backend == "hf-local":
        _ = self.hf_local_generator
    # Read the cached value directly; the access above already populated it.
    tts = self.__dict__.get("tts")
    if tts is not None:
        try:
            self._generate_tts(tts, "Okay.", None)
        except Exception as exc:
            # Warm-up is best-effort, but a failure here usually predicts failures at
            # request time, so make it visible.
            print(f"TTS warm-up failed: {exc}", file=sys.stderr)
async def preload_assistant(self) -> None:
    """Warm up a llama-server backend with a tiny chat completion; no-op for other backends.

    Best-effort: network or server errors are logged to stderr, never raised.
    """
    if settings.assistant_backend != "llama-server":
        return
    try:
        base_url = settings.llama_base_url.rstrip("/")
        # The model listing is requested at the server root. Strip a trailing "/v1" API
        # prefix if present; the old ``[:-3]`` chopped three characters unconditionally
        # and mangled base URLs that do not end in "/v1".
        models_url = f"{base_url.removesuffix('/v1')}/models"
        async with httpx.AsyncClient(timeout=10.0) as client:
            await client.get(models_url)
            payload = {
                "model": settings.llama_model,
                "messages": self._build_chat_messages(
                    settings.llama_system_prompt,
                    settings.llama_model,
                    "Say hello in two words.",
                ),
                "max_tokens": 8,
                "temperature": settings.llama_temperature,
                "top_p": settings.llama_top_p,
                "top_k": settings.llama_top_k,
                "repeat_penalty": settings.llama_repetition_penalty,
                "stop": [token.strip() for token in settings.llama_stop_tokens.split(",") if token.strip()],
                "stream": False,
            }
            await client.post(
                f"{base_url}/chat/completions",
                headers={"Content-Type": "application/json"},
                json=payload,
            )
    except Exception as exc:
        print(f"llama-server warm-up failed: {exc}", file=sys.stderr)
def _clear_cached_model(self, *names: str) -> None:
cleared = False
for name in names:
if name in self.__dict__:
del self.__dict__[name]
cleared = True
if cleared:
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
def _drop_whisper(self) -> None:
self._clear_cached_model("whisper")
def _drop_tts(self) -> None:
self._clear_cached_model("tts", "backchannel_clips", "tts_prefill_clips")
def _ensure_whisper_ready(self):
if self._prefer_low_vram_gpu_swap and not self._keep_tts_gpu_resident:
self._drop_tts()
return self.whisper
def _ensure_tts_ready(self):
if self._prefer_low_vram_gpu_swap:
self._drop_whisper()
return self.tts
def _tts_uses_cuda(self) -> bool:
    """True when the configured TTS backend's device/provider setting is "cuda"."""
    backend = settings.tts_backend
    if backend == "chatterbox-onnx":
        requested = settings.chatterbox_onnx_provider
    elif backend == "kokoro":
        requested = settings.kokoro_device
    else:
        # Any other value means the PyTorch Chatterbox backend.
        requested = settings.chatterbox_device
    return requested == "cuda"
@staticmethod
def _resolve_torch_device_static(requested_device: str, component: str) -> tuple[str, str | None]:
if requested_device != "cuda":
return requested_device, None
try:
import torch
except Exception as exc: # pragma: no cover
return "cpu", f"{component} fallback: torch CUDA probe failed ({exc}); using cpu"
if torch.cuda.is_available():
return requested_device, None
return "cpu", f"{component} fallback: CUDA requested but unavailable; using cpu"
@cached_property
def whisper(self):
    """Load faster-whisper, trying the configured compute type first, then fallbacks.

    Fallbacks are "float16", "int8", "float32" on CUDA and "int8", "float32" otherwise
    (skipping the configured type). Returns the model, or None after storing the last
    loader error in ``self._whisper_error``.
    """
    from faster_whisper import WhisperModel

    device, fallback_reason = self._resolve_torch_device(settings.whisper_device, component="Whisper")
    if fallback_reason:
        print(fallback_reason, file=sys.stderr)

    preferred = settings.whisper_compute_type
    fallbacks = ("float16", "int8", "float32") if device == "cuda" else ("int8", "float32")
    candidates = [preferred] + [kind for kind in fallbacks if kind != preferred]

    last_error: str | None = None
    for compute_type in candidates:
        try:
            model = WhisperModel(
                model_size_or_path=settings.whisper_model,
                device=device,
                compute_type=compute_type,
            )
            if device != settings.whisper_device:
                print(
                    f"Whisper fallback: using device={device} instead of {settings.whisper_device}",
                    file=sys.stderr,
                )
            if compute_type != preferred:
                print(
                    f"Whisper fallback: using compute_type={compute_type} instead of {preferred}",
                    file=sys.stderr,
                )
        except Exception as exc:  # pragma: no cover
            last_error = str(exc)
            continue
        self._whisper_error = None
        return model
    self._whisper_error = last_error
    return None
@cached_property
def parakeet(self):
    """Load the NeMo Parakeet TDT model, moving it to CPU if the requested device fails.

    Returns the model, or None after storing the error in ``self._whisper_error``
    (the STT error slot is shared with Whisper).
    """
    print(
        "STT load backend=parakeet-tdt-v3 "
        f"device={settings.parakeet_device} "
        f"model={settings.parakeet_model_id}",
        file=sys.stderr,
    )
    try:
        nemo_asr = importlib.import_module("nemo.collections.asr")
        model = nemo_asr.models.ASRModel.from_pretrained(model_name=settings.parakeet_model_id)
        device, fallback_reason = self._resolve_torch_device(settings.parakeet_device, component="Parakeet")
        if fallback_reason:
            print(fallback_reason, file=sys.stderr)
        try:
            model = model.to(device)
        except Exception as exc:
            if device == "cpu":
                raise
            print(
                f"Parakeet fallback: failed to move to {device} ({exc}); using cpu",
                file=sys.stderr,
            )
            model = model.to("cpu")
            device = "cpu"
        if device != settings.parakeet_device:
            print(
                f"Parakeet fallback: using device={device} instead of {settings.parakeet_device}",
                file=sys.stderr,
            )
        self._whisper_error = None
        return model
    except Exception as exc:  # pragma: no cover
        self._whisper_error = str(exc)
        print(f"Parakeet load failure: {exc}", file=sys.stderr)
        return None
@cached_property
def tts(self):
    """Load the configured TTS backend: "kokoro", "chatterbox-onnx", or (otherwise) PyTorch Chatterbox.

    If chatterbox-onnx fails to load, PyTorch Chatterbox is tried as a fallback. Returns
    the engine, or None after recording the failure(s) in ``self._tts_error``.
    """
    try:
        if settings.tts_backend == "kokoro":
            print(
                "TTS load backend=kokoro "
                f"device={settings.kokoro_device} "
                f"voice={settings.kokoro_voice}",
                file=sys.stderr,
            )
            return KokoroTTS()
        if settings.tts_backend == "chatterbox-onnx":
            tts = ChatterboxOnnxTTS()
            print(
                "TTS load backend=chatterbox-onnx "
                f"provider={settings.chatterbox_onnx_provider} "
                f"dtype={settings.chatterbox_onnx_dtype}",
                file=sys.stderr,
            )
            return tts
        return self._load_chatterbox_pytorch_tts()
    except Exception as exc:  # pragma: no cover
        self._tts_error = str(exc)
        # Log like the Parakeet loader does; previously TTS load failures were silent.
        print(f"TTS load failure backend={settings.tts_backend}: {exc}", file=sys.stderr)
        if settings.tts_backend != "chatterbox-onnx":
            return None
        try:
            fallback = self._load_chatterbox_pytorch_tts(fallback_from="onnx")
        except Exception as fallback_exc:  # pragma: no cover
            # Keep both causes; the fallback failure used to be discarded silently.
            self._tts_error = f"{exc}; PyTorch fallback also failed: {fallback_exc}"
            print(f"TTS fallback failure backend=chatterbox-pytorch: {fallback_exc}", file=sys.stderr)
            return None
        self._tts_error = f"ONNX backend failed, fell back to PyTorch: {exc}"
        return fallback

def _load_chatterbox_pytorch_tts(self, fallback_from: str | None = None):
    """Load ChatterboxTurboTTS on the configured (or CPU-fallback) torch device.

    ``fallback_from`` only tags the load log line (e.g. "onnx").
    """
    from chatterbox.tts_turbo import ChatterboxTurboTTS

    device, fallback_reason = self._resolve_torch_device(settings.chatterbox_device, component="Chatterbox")
    if fallback_reason:
        print(fallback_reason, file=sys.stderr)
    if device != settings.chatterbox_device:
        print(
            f"Chatterbox fallback: using device={device} instead of {settings.chatterbox_device}",
            file=sys.stderr,
        )
    origin = f" fallback_from={fallback_from}" if fallback_from else ""
    print(
        f"TTS load backend=chatterbox-pytorch{origin} device={device}",
        file=sys.stderr,
    )
    return ChatterboxTurboTTS.from_pretrained(device=device)
@cached_property
def backchannel_clips(self) -> dict[str, np.ndarray]:
    """Pre-render short listener acknowledgements ("mm-hmm", "yeah", "right").

    Phrases whose synthesis raises are left out; an unavailable TTS gives an empty dict.
    """
    engine = self.tts
    if engine is None:
        return {}
    rendered: dict[str, np.ndarray] = {}
    for phrase in ("mm-hmm", "yeah", "right"):
        try:
            audio = self._generate_tts(engine, phrase, None)
        except Exception:
            continue
        rendered[phrase] = audio
    return rendered
@cached_property
def tts_prefill_clips(self) -> dict[str, np.ndarray]:
    """Pre-render the comma-separated ``tts_prefill_choices``, keyed by lower-cased phrase.

    Empty when TTS is unavailable or prefill is disabled; blank entries and phrases
    whose synthesis raises are skipped.
    """
    engine = self.tts
    if engine is None or not settings.tts_prefill_enabled:
        return {}
    rendered: dict[str, np.ndarray] = {}
    for raw_phrase in settings.tts_prefill_choices.split(","):
        phrase = raw_phrase.strip()
        if not phrase:
            continue
        try:
            rendered[phrase.lower()] = self._generate_tts(engine, phrase, None)
        except Exception:
            continue
    return rendered
@cached_property
def hf_local_generator(self):
    """Load a local Hugging Face causal LM and its tokenizer.

    Returns a dict with keys "torch", "tokenizer", "model", "device" and "model_id".
    Unrecognized ``hf_local_dtype`` values fall back to float32.
    """
    transformers = importlib.import_module("transformers")
    torch = importlib.import_module("torch")
    model_id = settings.hf_local_model_id

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
    # generate() needs a pad id; reuse EOS when the tokenizer defines none.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    dtype_by_name = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    torch_dtype = dtype_by_name.get(settings.hf_local_dtype.lower().strip(), torch.float32)

    device, fallback_reason = self._resolve_torch_device(settings.hf_local_device, component="HF local LLM")
    if fallback_reason:
        print(fallback_reason, file=sys.stderr)
    load_kwargs: dict[str, object] = {"torch_dtype": torch_dtype}
    if device == "cuda":
        # Let transformers decide weight placement across the available GPU(s).
        load_kwargs["device_map"] = "auto"
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        **load_kwargs,
    )
    if device == "cpu":
        model.to("cpu")
    model.eval()

    if not settings.hf_local_do_sample:
        # Clear sampling knobs for greedy decoding (presumably to silence transformers'
        # warnings about sampling parameters that go unused). Best-effort.
        try:
            generation_config = model.generation_config
            generation_config.do_sample = False
            generation_config.temperature = None
            generation_config.top_p = None
            generation_config.top_k = None
        except Exception:
            pass

    return {
        "torch": torch,
        "tokenizer": tokenizer,
        "model": model,
        "device": device,
        "model_id": model_id,
    }
@cached_property
def my_agent_chat_model(self) -> str | None:
    """Model name the my-agent CLI uses for chat, or None if it cannot be determined.

    An explicit ``settings.my_agent_model`` wins. Otherwise the CLI is asked via
    ``config --get-model chat`` and its "Model for 'chat': ..." line is parsed.
    """
    if settings.my_agent_model:
        return settings.my_agent_model
    # __init__ reads this property *before* ``self._agent_cli`` is assigned. Referencing
    # it directly raised AttributeError, which the handler below swallowed, and None was
    # cached for the object's lifetime. Use a throwaway session just to find the binary.
    agent_cli = getattr(self, "_agent_cli", None) or PersistentAgentCliSession(None)
    try:
        result = subprocess.run(
            [agent_cli._resolve_my_agent_command(), "config", "--get-model", "chat"],
            capture_output=True,
            check=True,
            text=True,
            # A hung CLI must not block pipeline construction indefinitely.
            timeout=10,
        )
    except Exception:
        # Missing binary, non-zero exit (check=True) or timeout: fall back to the CLI default.
        return None
    match = re.search(r"Model for 'chat':\s*(.+)", result.stdout)
    if not match:
        return None
    model = match.group(1).strip()
    return model or None
def assistant_backend_metadata(self) -> tuple[str, str]:
    """Return ``(backend name, model label)`` describing the active assistant backend."""
    backend = settings.assistant_backend
    if backend == "my-agent-cli":
        model = self.my_agent_chat_model or "cli default"
    elif backend == "hf-local":
        model = settings.hf_local_model_id
    elif backend == "llama-server":
        model = settings.llama_model
    else:
        # Anything else is served through OpenRouter.
        model = settings.openrouter_model
    return backend, model
def reset_assistant_session(self) -> None:
    """Forget the remembered my-agent CLI session id so the next prompt does not resume it."""
    agent_cli = self._agent_cli
    agent_cli.reset_session()
async def transcribe(self, audio: np.ndarray) -> TranscriptionResult:
    """Transcribe a finished utterance with the primary decoding settings.

    Runs under ``self._model_lock``; inference happens in a worker thread. If the model
    failed to load, the result's text is a bracketed diagnostic naming the error.
    """
    if audio.size == 0:
        return TranscriptionResult(text="", backend="none")
    async with self._model_lock:
        use_parakeet = settings.stt_backend == "parakeet-tdt-v3" and not self._prefer_whisper_transcription()
        if use_parakeet:
            model = self.parakeet
            if model is None:
                reason = self._whisper_error or 'model failed to load'
                return TranscriptionResult(text=f"[parakeet unavailable] {reason}", backend="parakeet")
            return await asyncio.to_thread(self._run_parakeet_transcription, model, audio)
        model = self._ensure_whisper_ready()
        if model is None:
            reason = self._whisper_error or 'model failed to load'
            return TranscriptionResult(text=f"[whisper unavailable] {reason}", backend="whisper")
        decode_args = (
            settings.whisper_beam_size,
            settings.whisper_best_of,
            settings.whisper_log_prob_threshold,
            settings.whisper_no_speech_threshold,
        )
        return await asyncio.to_thread(self._run_transcription, model, audio, *decode_args)
async def transcribe_fallback(self, audio: np.ndarray) -> TranscriptionResult:
    """Re-transcribe with the fallback decoding settings; an unloadable model yields empty text."""
    if audio.size == 0:
        return TranscriptionResult(text="", backend="none")
    async with self._model_lock:
        use_parakeet = settings.stt_backend == "parakeet-tdt-v3" and not self._prefer_whisper_transcription()
        if use_parakeet:
            model = self.parakeet
            if model is None:
                return TranscriptionResult(text="", backend="parakeet")
            return await asyncio.to_thread(self._run_parakeet_transcription, model, audio)
        model = self._ensure_whisper_ready()
        if model is None:
            return TranscriptionResult(text="", backend="whisper")
        decode_args = (
            settings.whisper_fallback_beam_size,
            settings.whisper_fallback_best_of,
            settings.whisper_fallback_log_prob_threshold,
            settings.whisper_fallback_no_speech_threshold,
        )
        return await asyncio.to_thread(self._run_transcription, model, audio, *decode_args)
async def transcribe_partial(self, audio: np.ndarray) -> str:
    """Fast greedy Whisper pass for interim transcripts.

    Returns "" for empty audio, when Parakeet is the active STT backend, or when
    Whisper is unavailable.
    """
    if audio.size == 0:
        return ""
    async with self._model_lock:
        parakeet_active = settings.stt_backend == "parakeet-tdt-v3" and not self._prefer_whisper_transcription()
        if parakeet_active:
            return ""
        model = self._ensure_whisper_ready()
        if model is None:
            return ""

        def _decode() -> str:
            # Single-beam, temperature-0, no VAD or timestamps: trade accuracy for latency.
            segments, _ = model.transcribe(
                self._prepare_audio_for_whisper(audio),
                language=self._resolved_stt_language(),
                vad_filter=False,
                beam_size=1,
                best_of=1,
                temperature=0.0,
                condition_on_previous_text=False,
                without_timestamps=True,
                log_prob_threshold=settings.whisper_log_prob_threshold,
                no_speech_threshold=settings.whisper_no_speech_threshold,
            )
            joined = " ".join(segment.text.strip() for segment in segments)
            return self._normalize_transcript(joined.strip())

        return await asyncio.to_thread(_decode)
async def generate_reply(self, transcript: str) -> str:
    """Collect the streamed reply into one string, with a canned phrase if nothing was produced."""
    sentences = [sentence async for sentence in self.stream_reply_sentences(transcript)]
    if sentences:
        return " ".join(sentences).strip()
    if transcript.strip():
        return "Sorry, give me a second."
    return "I didn't catch that."
async def stream_reply_sentences(
    self,
    transcript: str,
    *,
    conversation_history: list[dict[str, str]] | None = None,
    response_language: str | None = None,
) -> AsyncIterator[str]:
    """Stream the assistant's reply to ``transcript`` one sentence at a time.

    Dispatches on ``settings.assistant_backend``: "hf-local", "llama-server", or
    (anything else) OpenRouter. Instead of raising, it yields a short apology when the
    backend fails or no usable OpenRouter key is configured.
    """
    cleaned = transcript.strip()
    if not cleaned:
        yield "I didn't catch that."
        return
    backend = settings.assistant_backend
    try:
        if backend == "hf-local":
            source = self._stream_hf_local_sentences(
                cleaned,
                conversation_history=conversation_history,
                response_language=response_language,
            )
        elif backend == "llama-server":
            source = self._stream_openai_compatible_sentences(
                transcript=cleaned,
                conversation_history=conversation_history,
                response_language=response_language,
                base_url=settings.llama_base_url,
                api_key=settings.llama_api_key,
                model=settings.llama_model,
                system_prompt=settings.llama_system_prompt,
                max_tokens=settings.llama_max_tokens,
                temperature=settings.llama_temperature,
                top_p=settings.llama_top_p,
                top_k=settings.llama_top_k,
                repetition_penalty=settings.llama_repetition_penalty,
                stop_tokens=[token.strip() for token in settings.llama_stop_tokens.split(",") if token.strip()],
            )
        else:
            api_key = os.getenv("OPENROUTER_API_KEY") or settings.openrouter_api_key
            # An unexpanded "$OPENROUTER_API_KEY" placeholder counts as no key.
            if not api_key or api_key.startswith("$OPENROUTER_API_KEY"):
                yield "Sorry, give me a second."
                return
            openrouter_headers = {
                "Authorization": f"Bearer {api_key}",
                "HTTP-Referer": settings.openrouter_site_url,
                "X-Title": settings.openrouter_app_name,
            }
            source = self._stream_openai_compatible_sentences(
                transcript=cleaned,
                conversation_history=conversation_history,
                response_language=response_language,
                base_url=settings.openrouter_base_url,
                api_key=api_key,
                model=settings.openrouter_model,
                system_prompt=settings.openrouter_system_prompt,
                max_tokens=settings.openrouter_max_tokens,
                temperature=settings.openrouter_temperature,
                top_p=1.0,
                top_k=0,
                repetition_penalty=1.0,
                stop_tokens=[],
                extra_headers=openrouter_headers,
            )
        async for sentence in source:
            yield sentence
    except Exception as exc:
        print(f"Assistant backend failure backend={backend}: {exc}", file=sys.stderr)
        yield "Sorry, give me a second."
async def _stream_hf_local_sentences(
self,
transcript: str,
*,
conversation_history: list[dict[str, str]] | None = None,
response_language: str | None = None,
) -> AsyncIterator[str]:
text = await asyncio.to_thread(
self._generate_hf_local_reply,
transcript,
conversation_history,
response_language,
)
if not text:
return
buffer = text.strip()
sentences, remainder = self._split_complete_sentences(buffer)
for sentence in sentences:
yield sentence
if remainder.strip():
yield remainder.strip()
def _generate_hf_local_reply(
    self,
    transcript: str,
    conversation_history: list[dict[str, str]] | None = None,
    response_language: str | None = None,
) -> str:
    """Run one blocking generation with the local HF model and return the normalized reply.

    Only newly generated tokens are decoded. Raw and normalized replies are echoed to
    stderr for debugging.
    """
    runtime = self.hf_local_generator
    torch = runtime["torch"]
    tokenizer = runtime["tokenizer"]
    model = runtime["model"]
    on_cuda = runtime["device"] == "cuda"

    encoded = self._encode_hf_local_prompt(
        tokenizer,
        transcript,
        conversation_history=conversation_history,
        response_language=response_language,
    )
    prompt_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    if on_cuda:
        prompt_ids = prompt_ids.to("cuda")
        attention_mask = attention_mask.to("cuda")

    sampling = settings.hf_local_do_sample
    generate_kwargs = {
        "input_ids": prompt_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": settings.hf_local_max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "do_sample": bool(sampling),
    }
    if sampling:
        generate_kwargs["temperature"] = settings.hf_local_temperature
        generate_kwargs["top_p"] = settings.hf_local_top_p

    with torch.inference_mode():
        generated = model.generate(**generate_kwargs)
    # Slice off the echoed prompt so only the continuation is decoded.
    continuation = generated[:, prompt_ids.shape[-1]:]
    raw_reply = tokenizer.decode(continuation[0], skip_special_tokens=True)
    normalized = self._normalize_hf_local_reply(raw_reply)
    print(f"HF local raw reply={raw_reply!r}", file=sys.stderr)
    print(f"HF local normalized reply={normalized!r}", file=sys.stderr)
    return normalized
def _encode_hf_local_prompt(
    self,
    tokenizer,
    transcript: str,
    *,
    conversation_history: list[dict[str, str]] | None = None,
    response_language: str | None = None,
):
    """Tokenize system prompt, history and the new transcript for local generation.

    The tokenizer's chat template is preferred. If the tokenizer has no
    ``apply_chat_template``, or applying it raises, the plain
    "User:/Assistant:" prompt from ``_build_hf_local_prompt`` is tokenized
    instead. The failure is logged to stderr.

    Returns:
        A mapping with at least ``input_ids`` and ``attention_mask`` as
        PyTorch tensors (``return_tensors="pt"``).
    """
    max_fallback_tokens = 1024
    prefix = self._augment_system_prompt(settings.hf_local_prompt_prefix, response_language)
    messages = []
    if prefix:
        messages.append({"role": "system", "content": prefix})
    for message in conversation_history or []:
        role = (message.get("role") or "").strip().lower()
        content = (message.get("content") or "").strip()
        if role in {"user", "assistant"} and content:
            messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": transcript})
    if hasattr(tokenizer, "apply_chat_template"):
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
            )
        except Exception as exc:
            # Tokenizers without a chat template (or older transformers that
            # reject ``return_dict``) end up here; degrade, but leave a trace.
            print(f"HF local chat template failed, using plain prompt: {exc}", file=sys.stderr)
    prompt = self._build_hf_local_prompt(
        transcript,
        conversation_history=conversation_history,
        response_language=response_language,
    )
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
    )
    # Keep the *end* of an over-long prompt: it holds the latest user turn and
    # the trailing "Assistant:" cue, which default right-side truncation
    # would cut off. (Any leading BOS token is lost in that case.)
    if encoded["input_ids"].shape[-1] > max_fallback_tokens:
        for key in list(encoded.keys()):
            encoded[key] = encoded[key][:, -max_fallback_tokens:]
    return encoded
async def _stream_openai_compatible_sentences(
    self,
    *,
    transcript: str,
    conversation_history: list[dict[str, str]] | None,
    response_language: str | None,
    base_url: str,
    api_key: str | None,
    model: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    stop_tokens: list[str],
    extra_headers: dict[str, str] | None = None,
) -> AsyncIterator[str]:
    """Stream a chat completion from an OpenAI-compatible server as text chunks.

    POSTs a ``"stream": True`` request to ``{base_url}/chat/completions`` and
    reads the server-sent ``data:`` lines until ``[DONE]``. Undecodable JSON
    lines are skipped. Deltas are accumulated, and each complete chunk found
    by ``_split_complete_sentences`` is yielded as soon as it appears. Any
    leftover text is yielded once the stream ends.

    ``top_k`` and ``repeat_penalty`` are not part of the OpenAI schema. They
    are sent for servers that understand them (presumably llama.cpp-style
    backends; confirm).

    Raises:
        httpx.HTTPStatusError: if the server replies with an error status.
    """
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    headers.update(extra_headers or {})

    messages = self._build_chat_messages(
        system_prompt,
        model,
        transcript,
        conversation_history=conversation_history,
        response_language=response_language,
    )
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repeat_penalty": repetition_penalty,
        "stream": True,
    }
    if stop_tokens:
        payload["stop"] = stop_tokens

    endpoint = f"{base_url.rstrip('/')}/chat/completions"
    pending = ""
    async with httpx.AsyncClient(timeout=20.0) as client:
        async with client.stream("POST", endpoint, headers=headers, json=payload) as response:
            response.raise_for_status()
            async for raw_line in response.aiter_lines():
                line = raw_line.strip()
                if not line.startswith("data:"):
                    continue
                data = line[len("data:"):].strip()
                if data == "[DONE]":
                    break
                try:
                    event = json.loads(data)
                except json.JSONDecodeError:
                    continue
                delta = self._extract_stream_delta(event)
                if not delta:
                    continue
                pending += delta
                ready, pending = self._split_complete_sentences(pending)
                for sentence in ready:
                    yield sentence
            leftover = pending.strip()
            if leftover:
                yield leftover
async def stream_agent_cli_events(self, transcript: str, session_id: str | None = None) -> AsyncIterator[dict]:
    """Relay agent-CLI events for one user turn.

    A pass-through over ``self._agent_cli.stream_events(transcript,
    session_id)``: every event is yielded unchanged and in order.
    """
    events = self._agent_cli.stream_events(transcript, session_id)
    async for item in events:
        yield item
def _build_chat_messages(
    self,
    system_prompt: str,
    model: str,
    transcript: str,
    *,
    conversation_history: list[dict[str, str]] | None = None,
    response_language: str | None = None,
) -> list[dict[str, str]]:
    """Build the ``messages`` list for an OpenAI-compatible chat request.

    History entries are kept only when their role is "user"/"assistant"
    (case-insensitive) and their stripped content is non-empty.

    For models whose name contains "gemma", everything is flattened into a
    single user message: the system text, then "User: ..."/"Assistant: ..."
    turns, then the new transcript and a trailing "Assistant:" cue, joined by
    blank lines. Presumably this is because those chat templates lack a system
    role; confirm. Other models get a normal system + history + user list.
    """
    system_text = self._augment_system_prompt(system_prompt, response_language)
    history: list[tuple[str, str]] = []
    for entry in conversation_history or []:
        speaker = (entry.get("role") or "").strip().lower()
        body = (entry.get("content") or "").strip()
        if body and speaker in ("user", "assistant"):
            history.append((speaker, body))

    if "gemma" in model.lower():
        lines = [system_text]
        lines.extend(f"{speaker.capitalize()}: {body}" for speaker, body in history)
        lines += [f"User: {transcript}", "Assistant:"]
        return [{"role": "user", "content": "\n\n".join(lines)}]

    messages = [{"role": "system", "content": system_text}]
    messages.extend({"role": speaker, "content": body} for speaker, body in history)
    messages.append({"role": "user", "content": transcript})
    return messages
def _build_hf_local_prompt(
    self,
    transcript: str,
    *,
    conversation_history: list[dict[str, str]] | None = None,
    response_language: str | None = None,
) -> str:
    """Render a plain-text "User:/Assistant:" prompt for the local model.

    With an empty augmented prefix the bare transcript is returned. Otherwise
    the result is the prefix, the usable history turns (user/assistant roles
    with non-empty content), "User: <transcript>" and a final "Assistant:"
    cue, separated by blank lines.
    """
    prefix = self._augment_system_prompt(settings.hf_local_prompt_prefix, response_language)
    if not prefix:
        return transcript
    labels = {"user": "User", "assistant": "Assistant"}
    turns = [prefix]
    for entry in conversation_history or []:
        label = labels.get((entry.get("role") or "").strip().lower())
        body = (entry.get("content") or "").strip()
        if label and body:
            turns.append(f"{label}: {body}")
    turns.extend((f"User: {transcript}", "Assistant:"))
    return "\n\n".join(turns)
def _augment_system_prompt(self, system_prompt: str, response_language: str | None) -> str:
    """Append the response-language instruction to a (stripped) system prompt.

    Empty parts are dropped; two non-empty parts are joined by one space.
    """
    parts = [system_prompt.strip(), self._language_instruction(response_language)]
    return " ".join(part for part in parts if part)
def _language_instruction(self, language_code: str | None) -> str:
    """Build the response-language sentence appended to system prompts.

    Args:
        language_code: An ISO 639-1 code such as ``"fr"``, optionally with a
            region suffix (``"pt-BR"``, ``"en_US"``), or ``"auto"``/``None``.
            Matching ignores case and surrounding whitespace.

    Returns:
        For ``None``, blank or ``"auto"``: an instruction to mirror the user's
        language. Otherwise an instruction naming the language. Codes missing
        from the table are named verbatim (stripped).
    """
    code = (language_code or "").strip()
    normalized = code.lower()
    if not normalized or normalized == "auto":
        return (
            "By default, reply in the same language as the user's latest message. "
            "If the user explicitly asks for translation, comparison, or output in two or more languages, include all requested languages."
        )
    language_names = {
        "en": "English",
        "es": "Spanish",
        "fr": "French",
        "de": "German",
        "it": "Italian",
        "pt": "Portuguese",
        "nl": "Dutch",
        "ru": "Russian",
        "uk": "Ukrainian",
        "pl": "Polish",
        "tr": "Turkish",
        "ar": "Arabic",
        "hi": "Hindi",
        "ja": "Japanese",
        "ko": "Korean",
        "zh": "Chinese",
    }
    # Look up the primary language subtag so regional variants still resolve.
    primary = normalized.replace("_", "-").split("-", 1)[0]
    label = language_names.get(primary, code)
    return (
        f"By default, reply in {label}. "
        "If the user explicitly asks for translation, comparison, or output in two or more languages, include all requested languages. "
        "Keep the answer natural and concise."
    )
async def stream_synthesized_chunks(self, text: str, voice_prompt_path: str | None = None) -> AsyncIterator[np.ndarray]:
    """Synthesize *text* piece by piece, yielding one waveform per piece.

    Pieces come from ``_split_tts_chunks``. If it returns nothing, the raw text
    is used as the single piece. ``self._model_lock`` is held for the whole
    utterance. If ``_ensure_tts_ready`` returns no engine, a single beep is
    yielded instead. Each generation runs in a worker thread.
    """
    pieces = self._split_tts_chunks(text) or [text]
    async with self._model_lock:
        engine = self._ensure_tts_ready()
        if engine is None:
            yield self._beep()
            return
        for piece in pieces:
            yield await asyncio.to_thread(self._generate_tts, engine, piece, voice_prompt_path)
async def synthesize_sentences(self, text: str, voice_prompt_path: str | None = None) -> list[np.ndarray]:
    """Collect every waveform from ``stream_synthesized_chunks`` into a list."""
    collected: list[np.ndarray] = []
    async for audio in self.stream_synthesized_chunks(text, voice_prompt_path=voice_prompt_path):
        collected.append(audio)
    return collected
async def stream_reply_audio(self, text: str, transcript: str, voice_prompt_path: str | None = None) -> AsyncIterator[np.ndarray]:
    """Stream audio for an assistant reply, optionally led by a cached filler clip.

    A filler ("prefill") clip is yielded first only when all of these hold:
    low-VRAM GPU swapping is off, a TTS engine is loaded, the default voice
    is requested, ``settings.tts_prefill_enabled`` is set, and the stripped
    reply has at least ``settings.tts_prefill_min_chars`` characters. The
    phrase comes from ``_choose_tts_prefill`` and is looked up lower-cased in
    ``tts_prefill_clips``. The reply itself then follows from
    ``stream_synthesized_chunks``.
    """
    filler = ""
    if (
        not self._prefer_low_vram_gpu_swap
        and self.tts is not None
        and voice_prompt_path is None
        and settings.tts_prefill_enabled
        and len(text.strip()) >= settings.tts_prefill_min_chars
    ):
        filler = self._choose_tts_prefill(text, transcript)
    clip = self.tts_prefill_clips.get(filler.lower()) if filler else None
    if clip is not None:
        # Yield a copy, presumably so consumers cannot alter the cached clip.
        yield clip.copy()
    async for audio in self.stream_synthesized_chunks(text, voice_prompt_path=voice_prompt_path):
        yield audio
async def synthesize_reply(self, text: str, transcript: str, voice_prompt_path: str | None = None) -> list[np.ndarray]:
    """Collect all audio from ``stream_reply_audio`` (filler clip included) into a list."""
    stream = self.stream_reply_audio(text, transcript, voice_prompt_path=voice_prompt_path)
    return [segment async for segment in stream]
async def synthesize_backchannel(self, text: str, voice_prompt_path: str | None = None) -> list[np.ndarray]:
    """Return audio for a short backchannel utterance.

    When low-VRAM GPU swapping is off and no custom voice prompt is given, a
    pre-rendered clip from ``backchannel_clips`` (keyed by stripped,
    lower-cased text) is returned as a copy. Otherwise, or on a cache miss,
    the text is synthesized via ``synthesize_sentences``.
    """
    use_cache = not self._prefer_low_vram_gpu_swap and not voice_prompt_path
    if use_cache:
        hit = self.backchannel_clips.get(text.strip().lower())
        if hit is not None:
            return [hit.copy()]
    return await self.synthesize_sentences(text, voice_prompt_path=voice_prompt_path)
def _generate_tts(self, tts, text: str, voice_prompt_path: str | None) -> np.ndarray:
    """Run one TTS generation and return a flat float32 waveform.

    On the first call, logs the engine class and any truthy ``_provider``,
    ``_dtype`` and ``_model_id`` attributes to stderr. A voice prompt path is
    forwarded as ``audio_prompt_path``. Outputs with a ``detach`` method
    (tensor-like) are moved to CPU NumPy before conversion.
    """
    if not self._tts_runtime_logged:
        details = [f"class={type(tts).__name__}"]
        for label, attribute in (("provider", "_provider"), ("dtype", "_dtype"), ("model", "_model_id")):
            value = getattr(tts, attribute, None)
            if value:
                details.append(f"{label}={value}")
        print(f"TTS generate {' '.join(details)}", file=sys.stderr)
        self._tts_runtime_logged = True
    extra = {"audio_prompt_path": voice_prompt_path} if voice_prompt_path else {}
    waveform = tts.generate(text, **extra)
    if hasattr(waveform, "detach"):
        waveform = waveform.detach().cpu().numpy()
    return np.asarray(waveform, dtype=np.float32).flatten()
def _beep(self) -> np.ndarray:
    """Return a quiet 0.25 s, 660 Hz sine tone sampled at 24 kHz (float32).

    Used as the audible fallback when no TTS engine is available.
    """
    rate = 24000
    seconds = 0.25
    sample_count = int(rate * seconds)
    t = np.linspace(0, seconds, sample_count, endpoint=False)
    tone = 0.15 * np.sin(2 * np.pi * 660 * t)
    return tone.astype(np.float32)
def _extract_stream_delta(self, event: dict) -> str:
    """Return the text increment carried by one streamed chat-completion event.

    ``choices[0].delta.content`` may be a plain string or a list of content
    parts, of which only ``{"type": "text", "text": <str>}`` parts are kept.
    Malformed payloads return ``""`` instead of raising, so one odd SSE
    message cannot abort the stream. This covers non-object events, missing
    or empty ``choices``, and non-dict choices or deltas.
    """
    if not isinstance(event, dict):
        return ""
    choices = event.get("choices") or []
    if not isinstance(choices, (list, tuple)) or not choices:
        return ""
    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        return ""
    delta = first_choice.get("delta") or {}
    if not isinstance(delta, dict):
        return ""
    content = delta.get("content") or ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            item["text"]
            for item in content
            if isinstance(item, dict)
            and item.get("type") == "text"
            and isinstance(item.get("text"), str)
        )
    return ""
def render_assistant_text(self, text: str, transcript: str) -> tuple[str, str]:
    """Prepare a reply for both output channels.

    Returns:
        ``(display_text, tts_text)``. The speech text keeps paralinguistic
        tags (e.g. ``[laugh]``); the display text has them stripped.
    """
    spoken = self._prepare_tts_text(text, transcript)
    return self._strip_paralinguistic_tags(spoken), spoken
def _split_complete_sentences(self, buffer: str) -> tuple[list[str], str]:
    """Cut every complete chunk off the front of *buffer*.

    ``_next_stream_boundary_end`` is asked repeatedly where the next chunk
    ends. Returns the stripped, non-empty chunks in order, plus the
    left-stripped text after the last boundary, which is still waiting for
    more input.
    """
    pieces: list[str] = []
    rest = buffer.lstrip()
    while rest:
        cut = self._next_stream_boundary_end(rest)
        if cut is None:
            break
        piece, rest = rest[:cut].strip(), rest[cut:].lstrip()
        if piece:
            pieces.append(piece)
    return pieces, rest
def _next_stream_boundary_end(self, text: str) -> int | None:
    """Return the index just past the next chunk boundary in *text*, or ``None``."""
    boundary = self._find_chunk_boundary(text)
    return None if boundary is None else boundary.end()
def _find_chunk_boundary(self, text: str) -> re.Match[str] | None:
    """Locate the end of the next speakable chunk in *text*.

    Candidates, each of which must be followed by whitespace or end of text:

    * a clause ending in ``,``, ``:`` or ``;`` after at least 18 characters;
    * only if there is no such clause, a "soft" clause ending at one of
      and/but/so/because/then/while/which after at least 22 characters;
    * a sentence ending in ``.``, ``!`` or ``?``.

    The clause candidate wins when it ends no later than the sentence.
    Returns the winning match, or ``None`` if there is no candidate.
    """
    sentence = re.search(r".*?[.!?](?:\s+|$)", text, flags=re.S)
    clause = re.search(r".{18,}?[,:;](?:\s+|$)", text, flags=re.S)
    if clause is None:
        clause = re.search(
            r".{22,}?\b(?:and|but|so|because|then|while|which)\b(?:\s+|$)",
            text,
            flags=re.S | re.I,
        )
    if clause is None or (sentence is not None and sentence.end() < clause.end()):
        return sentence
    return clause
def _find_early_stream_boundary(self, text: str) -> int | None:
    """Find a whitespace boundary for flushing a chunk before punctuation arrives.

    Only the first ``max(min_chars, max_chars)`` characters are searched,
    using the ``settings.assistant_stream_chunk_*`` values, and the last
    whitespace run in that window is taken. The boundary is rejected in
    three cases:

    * the text before it, trimmed of whitespace and ``,;:``, is shorter than
      ``assistant_stream_chunk_min_chars``;
    * it has fewer than ``assistant_stream_chunk_min_words`` words;
    * it ends on a function word ("the", "and", "to", ...) that would leave
      the phrase dangling.

    Returns:
        The index just past that whitespace run, or ``None``.
    """
    min_chars = settings.assistant_stream_chunk_min_chars
    if len(text) < min_chars:
        return None
    max_chars = max(min_chars, settings.assistant_stream_chunk_max_chars)
    search_limit = min(len(text), max_chars)
    dangling_words = frozenset(
        (
            "a an the and or but so because then while which who what when where "
            "why how if that this these those to of for with at from in on is are "
            "was were be been being do does did can could should would will have "
            "has had"
        ).split()
    )
    boundary = None
    for match in re.finditer(r"\s+", text[:search_limit]):
        boundary = match.end()
    if boundary is None:
        return None
    candidate = text[:boundary].strip(" \t\r\n,;:")
    if len(candidate) < min_chars:
        return None
    words = re.findall(r"\b[\w'-]+\b", candidate)
    if len(words) < settings.assistant_stream_chunk_min_words:
        return None
    # ``words`` can be empty when the minimum word count is 0 and the candidate
    # holds only punctuation; guard before indexing the last word.
    if words and words[-1].lower() in dangling_words:
        return None
    return boundary
def _prepare_tts_text(self, text: str, transcript: str) -> str:
    """Turn an assistant reply into the text sent to TTS.

    The reply is normalized and its paralinguistic tags are lower-cased. The
    reply is prefixed with "Mm-hmm, " when all of these hold:
    ``settings.tts_auto_ack_prefix_enabled`` is on, the user's transcript has
    at least ten words, and the reply does not already open with an
    acknowledgement. Its first letter is then lower-cased so the result reads
    as one sentence.

    The first letter keeps its case when the reply opens with the pronoun
    "I" (or a contraction such as "I'm") or an all-caps acronym, since
    lower-casing those would mis-render them.
    """
    cleaned = self._normalize_reply_text(text)
    if not cleaned:
        return ""
    cleaned = self._normalize_existing_tags(cleaned)
    transcript_word_count = len(re.findall(r"\w+", transcript))
    needs_ack = (
        settings.tts_auto_ack_prefix_enabled
        and transcript_word_count >= 10
        and not re.match(r"(?i)^(yeah|right|mm-hmm|got it|okay|ok|sure)\b", cleaned)
    )
    if not needs_ack:
        return cleaned
    first_token_match = re.match(r"[\w'-]+", cleaned)
    first_token = first_token_match.group(0) if first_token_match else ""
    keep_case = (
        first_token == "I"
        or first_token.startswith("I'")
        or (len(first_token) > 1 and first_token.isupper())
    )
    body = cleaned if keep_case else cleaned[0].lower() + cleaned[1:]
    return f"Mm-hmm, {body}"
def _split_tts_chunks(self, text: str) -> list[str]:
    """Return TTS chunks for *text*: the whole stripped text as one chunk, or ``[]`` if blank."""
    stripped = text.strip()
    return [stripped] if stripped else []
def _extract_short_lead_chunk(self, text: str) -> str:
    """Return a short opening piece of *text* suitable as a first TTS chunk.

    Text within both ``settings.tts_first_chunk_max_words`` and
    ``settings.tts_first_chunk_max_chars`` is returned whole. Otherwise the
    longest leading clause of at most ``max_chars`` characters that ends in
    ``,``/``;``/``:`` is used, provided it has at most ``max_words + 2``
    words. Failing that, the first ``max_words`` words are used with trailing
    ``,:;`` removed, or the whole text if that leaves nothing.
    """
    max_words = settings.tts_first_chunk_max_words
    max_chars = settings.tts_first_chunk_max_chars
    words = text.split()
    if len(words) <= max_words and len(text) <= max_chars:
        return text
    clause = re.search(rf"^(.{{1,{max_chars}}}[,;:])(?=\s|$)", text)
    if clause:
        lead = clause.group(1).strip()
        if len(lead.split()) <= max_words + 2:
            return lead
    fallback = " ".join(words[:max_words]).strip()
    return re.sub(r"[,:;]+$", "", fallback).strip() or text
def _choose_tts_prefill(self, text: str, transcript: str) -> str:
    """Pick the filler phrase to play before a reply, or ``""`` for none.

    Replies that already open with an acknowledgement get no filler.
    Otherwise one phrase from the comma-separated
    ``settings.tts_prefill_choices`` is chosen deterministically from the
    (transcript, reply) pair. The same pair always yields the same phrase,
    across restarts too.
    """
    import zlib

    if re.match(r"(?i)^(okay|yeah|right|got it|mm-hmm|uh-huh)\b", text.strip()):
        return ""
    choices = [item.strip() for item in settings.tts_prefill_choices.split(",") if item.strip()]
    if not choices:
        return ""
    seed = f"{transcript.strip().lower()}::{text.strip().lower()}"
    # crc32 rather than hash(): str hashes are salted per process
    # (PYTHONHASHSEED), so hash() would pick different fillers after restarts.
    digest = zlib.crc32(seed.encode("utf-8"))
    return choices[digest % len(choices)]
def _normalize_reply_text(self, text: str) -> str:
    """Strip internal notices and hidden reasoning, then tidy spacing.

    Steps:

    * remove every ``INTERNAL_REPLY_PATTERNS`` notice (memory-injection
      banners);
    * when ``settings.hf_local_hide_thinking`` is on, remove ``<think>``
      reasoning: complete blocks, any text before an unmatched ``</think>``,
      any text after an unmatched ``<think>``, and leftover tags;
    * collapse whitespace;
    * drop leading punctuation and spaces before punctuation.
    """
    cleaned = text
    for pattern in INTERNAL_REPLY_PATTERNS:
        cleaned = pattern.sub(" ", cleaned)
    if settings.hf_local_hide_thinking:
        cleaned = THINK_BLOCK_PATTERN.sub(" ", cleaned)
        # A closing tag left without an opener means the reasoning began in
        # the prompt (a template that pre-fills "<think>"); drop it all.
        cleaned = re.sub(r"(?is)^.*</think>", " ", cleaned)
        # An opener left without a closer means generation stopped mid-thought
        # (e.g. max_new_tokens reached); drop everything from it onward.
        cleaned = re.sub(r"(?is)<think>.*$", " ", cleaned)
        cleaned = THINK_TAG_PATTERN.sub(" ", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    cleaned = re.sub(r"^\s*[,.:;!?-]+\s*", "", cleaned)
    cleaned = re.sub(r"\s+([,.:;!?])", r"\1", cleaned)
    return cleaned
def _normalize_hf_local_reply(self, text: str) -> str:
    """Clean a raw local-model completion into a short spoken reply.

    Steps, in order:

    1. apply the shared reply normalization;
    2. strip a leading "assistant:" label;
    3. cut at any continuation of the dialogue ("user:", "human:", ...);
    4. drop trailing meta-commentary (``META_TAIL_PATTERN``);
    5. remove emoji;
    6. collapse repeated text;
    7. keep at most two sentences.
    """
    cleaned = self._normalize_reply_text(text)
    cleaned = re.sub(r"(?i)^assistant:\s*", "", cleaned)
    cleaned = re.sub(r"(?i)\buser:\s*.*$", "", cleaned).strip()
    role_match = ROLE_CONTINUATION_PATTERN.search(cleaned)
    if role_match:
        cleaned = cleaned[: role_match.start()].strip()
    cleaned = META_TAIL_PATTERN.sub("", cleaned).strip()
    # Emoji cannot be spoken. This covers regional-indicator flags, the main
    # pictograph/emoticon planes (U+1F300-1FAFF), misc symbols and dingbats,
    # plus the zero-width joiner and variation selector used in sequences.
    cleaned = re.sub(r"[\U0001F1E6-\U0001F1FF\U0001F300-\U0001FAFF\u2600-\u27BF\u200D\uFE0F]+", "", cleaned)
    # Removing emoji can leave doubled spaces or a space before punctuation.
    cleaned = re.sub(r"\s{2,}", " ", cleaned)
    cleaned = re.sub(r"\s+([,.:;!?])", r"\1", cleaned).strip()
    cleaned = self._dedupe_repeated_reply(cleaned)
    cleaned = self._limit_spoken_sentences(cleaned, max_sentences=2)
    return cleaned
def _dedupe_repeated_reply(self, text: str) -> str:
    """Collapse obvious repetition in a model reply.

    1. If the reply is the same text said twice, one copy is returned. The
       halves are compared case-insensitively, ignoring surrounding
       whitespace and punctuation. The copy keeps its own ``.!?`` but not a
       dangling ``,;:``.
    2. Otherwise, sentences that repeat the previous sentence are dropped,
       and so are questions already asked earlier in the reply.
    """
    cleaned = text.strip()
    if not cleaned:
        return cleaned
    edge_chars = " \t\r\n,.;:!?"
    # Try both midpoints: "A. A." has odd length because of the separating
    # space, which an even-length-only split would never catch.
    for half in (len(cleaned) // 2, (len(cleaned) + 1) // 2):
        left = cleaned[:half].strip(edge_chars)
        right = cleaned[half:].strip(edge_chars)
        if left and left.lower() == right.lower():
            return cleaned[:half].strip().rstrip(",;:").strip()
    kept: list[str] = []
    for sentence in re.split(r"(?<=[.!?])\s+", cleaned):
        sentence = sentence.strip()
        if not sentence:
            continue
        key = sentence.lower()
        if kept and kept[-1].lower() == key:
            continue
        if sentence.endswith("?") and any(previous.lower() == key for previous in kept):
            continue
        kept.append(sentence)
    return " ".join(kept).strip()
def _limit_spoken_sentences(self, text: str, max_sentences: int) -> str:
    """Keep at most *max_sentences* sentences of *text*.

    Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace
    and re-joined with single spaces. A non-positive limit disables
    truncation, and the stripped text is returned as-is.
    """
    stripped = text.strip()
    if not stripped or max_sentences <= 0:
        return stripped
    sentences = [piece.strip() for piece in re.split(r"(?<=[.!?])\s+", stripped)]
    sentences = [piece for piece in sentences if piece][:max_sentences]
    return " ".join(sentences).strip() if sentences else stripped
def _normalize_existing_tags(self, text: str) -> str:
    """Lower-case paralinguistic tags (``[LAUGH]`` -> ``[laugh]``); other text is untouched."""
    def lower_tag(match):
        return match.group(0).lower()

    return PARALINGUISTIC_TAG_PATTERN.sub(lower_tag, text)
def _strip_paralinguistic_tags(self, text: str) -> str:
    """Remove paralinguistic tags for display, then tidy the spacing they leave.

    Whitespace runs collapse to one space, spaces before ``,.:;!?`` are
    removed, and the result is stripped.
    """
    result = PARALINGUISTIC_TAG_PATTERN.sub("", text)
    for pattern, replacement in ((r"\s+", " "), (r"\s+([,.:;!?])", r"\1")):
        result = re.sub(pattern, replacement, result)
    return result.strip()
def _normalize_transcript(self, transcript: str) -> str:
    """Return the stripped transcript, or ``""`` for empty or known filler output.

    The check is case-insensitive and ignores punctuation other than
    apostrophes. Blank text and "thank you/thanks for watching" (presumably
    STT hallucinations on silence) are discarded.
    """
    stripped = transcript.strip()
    simplified = re.sub(r"[^\w\s']", "", stripped.lower())
    simplified = re.sub(r"\s+", " ", simplified).strip()
    ignored = ("", "thank you for watching", "thanks for watching")
    return "" if simplified in ignored else stripped
def is_likely_hallucination(self, transcript: str, audio_rms: float) -> bool:
    """Heuristically flag a transcript as a speech-to-text hallucination.

    Returns True only when all of these hold:

    * the normalized transcript is non-empty;
    * ``audio_rms`` is at most ``settings.hallucination_max_rms``;
    * it has at most ``settings.hallucination_max_words`` words;
    * it equals one of the comma-separated ``settings.hallucination_phrases``.

    Transcript and configured phrases are normalized the same way:
    lower-cased, punctuation other than apostrophes removed, whitespace
    collapsed. Configured "Thank you." therefore matches a transcript
    "thank you".
    """
    def simplify(value: str) -> str:
        value = re.sub(r"[^\w\s']", "", value.lower())
        return re.sub(r"\s+", " ", value).strip()

    normalized = simplify(transcript)
    if not normalized:
        return False
    if audio_rms > settings.hallucination_max_rms:
        return False
    words = [word for word in normalized.split(" ") if word]
    if len(words) > settings.hallucination_max_words:
        return False
    blocked = {simplify(phrase) for phrase in settings.hallucination_phrases.split(",")}
    blocked.discard("")
    return normalized in blocked
def _resolve_torch_device(self, requested_device: str, component: str) -> tuple[str, str | None]:
    """Delegate to ``_resolve_torch_device_static`` and return its result unchanged."""
    resolver = self._resolve_torch_device_static
    return resolver(requested_device, component)
def _resolved_stt_language(self) -> str | None:
    """Return the configured STT language code (lower-cased), or ``None`` for blank/"auto"."""
    configured = settings.stt_language.strip().lower()
    return None if configured in ("", "auto") else configured
def _whisper_model_supports_multilingual(self) -> bool:
    """Whether the configured Whisper model name lacks the English-only ``.en`` suffix."""
    model_name = settings.whisper_model.strip().lower()
    return not model_name.endswith(".en")
def _prefer_whisper_transcription(self) -> bool:
    """Decide whether Whisper should transcribe instead of Parakeet.

    Any STT backend other than ``parakeet-tdt-v3`` means Whisper. With that
    Parakeet backend, Whisper is used only when multilingual STT is enabled
    and the Whisper model itself is multilingual.
    """
    if settings.stt_backend == "parakeet-tdt-v3":
        return bool(settings.stt_multilingual_enabled) and self._whisper_model_supports_multilingual()
    return True
def _prepare_audio_for_whisper(self, audio: np.ndarray) -> np.ndarray:
    """Pad very short utterances so the recognizer gets enough context.

    Audio longer than ``settings.short_utterance_ms`` passes through apart
    from a float32 cast. Shorter audio gets ``short_utterance_pad_ms`` of
    silence on each side. It is then centred in extra silence until it
    reaches ``short_utterance_min_transcription_ms``. Assumes mono 1-D
    samples at ``settings.sample_rate``; confirm for multi-channel input.

    Returns:
        A float32 array on every path, so both STT backends receive the same
        dtype.
    """
    sample_rate = settings.sample_rate
    duration_ms = (audio.size / sample_rate) * 1000.0
    if duration_ms > settings.short_utterance_ms:
        return audio.astype(np.float32, copy=False)
    pad_samples = int(sample_rate * (settings.short_utterance_pad_ms / 1000.0))
    prepared = np.pad(audio, (pad_samples, pad_samples), mode="constant")
    min_samples = int(sample_rate * (settings.short_utterance_min_transcription_ms / 1000.0))
    shortfall = min_samples - prepared.size
    if shortfall > 0:
        left = shortfall // 2
        prepared = np.pad(prepared, (left, shortfall - left), mode="constant")
    return prepared.astype(np.float32, copy=False)
def _run_transcription(
    self,
    whisper,
    audio: np.ndarray,
    beam_size: int,
    best_of: int,
    log_prob_threshold: float | None,
    no_speech_threshold: float | None,
) -> TranscriptionResult:
    """Transcribe *audio* with a Whisper model whose ``transcribe`` returns ``(segments, info)``.

    Built-in VAD and timestamps are off. Decoding temperature and
    previous-text conditioning come from settings. ``language`` and the two
    thresholds are passed only when they are not ``None``. Segment texts are
    joined with spaces and passed through ``_normalize_transcript``. The
    detected language and its probability are read from ``info`` when
    present.
    """
    prepared = self._prepare_audio_for_whisper(audio)
    options = {
        "vad_filter": False,
        "beam_size": beam_size,
        "best_of": best_of,
        "temperature": settings.whisper_temperature,
        "condition_on_previous_text": settings.whisper_condition_on_previous_text,
        "without_timestamps": True,
    }
    optional = {
        "language": self._resolved_stt_language(),
        "log_prob_threshold": log_prob_threshold,
        "no_speech_threshold": no_speech_threshold,
    }
    options.update({key: value for key, value in optional.items() if value is not None})
    segments, info = whisper.transcribe(prepared, **options)
    text = " ".join(segment.text.strip() for segment in segments).strip()
    return TranscriptionResult(
        text=self._normalize_transcript(text),
        language=getattr(info, "language", None),
        language_probability=getattr(info, "language_probability", None),
        backend="whisper",
    )
def _run_parakeet_transcription(self, parakeet, audio: np.ndarray) -> TranscriptionResult:
    """Transcribe *audio* with a Parakeet model through a temporary WAV file.

    The model's ``transcribe`` is called with a list of file paths, so the
    prepared audio is written to a temp ``.wav`` that is always deleted
    afterwards. A result may be an object with a ``.text`` attribute or a
    plain value, which is stringified.

    NOTE(review): the language is always reported as ``"en"`` — confirm this
    is intended for multilingual Parakeet checkpoints.
    """
    prepared_audio = self._prepare_audio_for_whisper(audio)
    wav_bytes = wav_bytes_from_float32(prepared_audio, settings.sample_rate)
    temp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
            handle.write(wav_bytes)
            temp_path = handle.name
        # Transcribe only after the ``with`` has closed the file: the bytes are
        # flushed to disk, and on Windows the path can then be reopened.
        results = parakeet.transcribe([temp_path])
        if not results:
            return TranscriptionResult(text="", language="en", backend="parakeet")
        first = results[0]
        # Some NeMo releases return (best_hypotheses, all_hypotheses) for
        # transducer models, which makes ``first`` a list; unwrap one level.
        if isinstance(first, (list, tuple)):
            if not first:
                return TranscriptionResult(text="", language="en", backend="parakeet")
            first = first[0]
        transcript = getattr(first, "text", None)
        if transcript is None:
            transcript = str(first)
        return TranscriptionResult(
            text=self._normalize_transcript(str(transcript).strip()),
            language="en",
            backend="parakeet",
        )
    finally:
        if temp_path:
            try:
                os.unlink(temp_path)
            except FileNotFoundError:
                pass
# Module-level SpeechPipeline instance, constructed once when this module is
# imported; any work done by SpeechPipeline.__init__ therefore runs at import time.
pipeline = SpeechPipeline()