File size: 4,965 Bytes
31e5b3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
voice_interface.py — Voice I/O for the Computer Agent
======================================================
Speech-to-Text (Whisper / Faster-Whisper) and TTS (HF Inference API)
"""

import os
import io
import tempfile
import base64
from typing import Optional, Dict, Any

import numpy as np

# STT
try:
    from faster_whisper import WhisperModel
    HAS_FASTER_WHISPER = True
except ImportError:
    HAS_FASTER_WHISPER = False

# TTS via HF Inference
try:
    from huggingface_hub import InferenceClient
    HAS_HF_INFERENCE = True
except ImportError:
    HAS_HF_INFERENCE = False


class VoiceInterface:
    """Handles audio input (STT) and output (TTS) for the agent.

    STT runs faster-whisper locally (CPU, int8); TTS goes through the
    Hugging Face Inference API.  Both backends are created lazily on
    first use so importing this module stays cheap.
    """

    def __init__(
        self,
        stt_model_size: str = "base",
        tts_model: str = "hexgrad/Kokoro-82M",
        hf_token: Optional[str] = None,
    ):
        """
        Args:
            stt_model_size: Whisper model size ("tiny", "base", "small", ...).
            tts_model: Hugging Face model id used for text-to-speech.
            hf_token: HF API token; falls back to the HF_TOKEN env var.
        """
        self.stt_model_size = stt_model_size
        self.tts_model = tts_model
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self._stt: Optional[Any] = None          # cached WhisperModel
        self._tts_client: Optional[Any] = None   # cached InferenceClient

    # ------------------------------------------------------------------
    # STT
    # ------------------------------------------------------------------

    def _load_stt(self) -> Any:
        """Lazily create and cache the faster-whisper model.

        Raises:
            RuntimeError: if faster-whisper is not installed.
        """
        if self._stt is None:
            if not HAS_FASTER_WHISPER:
                raise RuntimeError("faster-whisper not installed. Run: pip install faster-whisper")
            # Use CPU for Spaces compatibility; int8 keeps the memory
            # footprint small.
            self._stt = WhisperModel(self.stt_model_size, device="cpu", compute_type="int8")
        return self._stt

    def transcribe(self, audio_np: np.ndarray, sample_rate: int = 16000) -> Dict[str, Any]:
        """Transcribe an audio waveform to text.

        Args:
            audio_np: float32 mono samples, ideally normalized to [-1, 1].
            sample_rate: sample rate of ``audio_np`` in Hz.

        Returns:
            Dict with "text", "language" and "probability" keys.
        """
        model = self._load_stt()
        # faster-whisper expects a file path or bytes; save to temp wav.
        # Close the handle before soundfile writes / whisper reads the same
        # path (required on Windows), and always delete the temp file even
        # if transcription raises.
        import soundfile as sf
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        try:
            sf.write(tmp.name, audio_np, sample_rate)
            # `segments` is a lazy generator — decoding happens during
            # iteration, so it must be fully consumed before the file
            # is removed.
            segments, info = model.transcribe(tmp.name, beam_size=5)
            text = " ".join(seg.text for seg in segments)
        finally:
            os.unlink(tmp.name)
        return {
            "text": text.strip(),
            "language": info.language,
            "probability": info.language_probability,
        }

    def transcribe_from_file(self, file_path: str) -> Dict[str, Any]:
        """Transcribe an existing audio file; same return shape as transcribe()."""
        model = self._load_stt()
        segments, info = model.transcribe(file_path, beam_size=5)
        text = " ".join(seg.text for seg in segments)
        return {
            "text": text.strip(),
            "language": info.language,
            "probability": info.language_probability,
        }

    # ------------------------------------------------------------------
    # TTS
    # ------------------------------------------------------------------

    def _load_tts(self) -> Any:
        """Lazily create and cache the HF InferenceClient for TTS.

        Raises:
            RuntimeError: if huggingface_hub is not installed.
        """
        if self._tts_client is None:
            if not HAS_HF_INFERENCE:
                raise RuntimeError("huggingface_hub not installed")
            self._tts_client = InferenceClient(model=self.tts_model, token=self.hf_token)
        return self._tts_client

    def synthesize(self, text: str, voice: str = "af") -> bytes:
        """Synthesize text to speech bytes.

        Args:
            text: text to speak.
            voice: voice preset; currently not forwarded to the HF endpoint,
                kept for interface stability.  TODO confirm whether the
                Kokoro endpoint accepts a voice parameter.

        Returns:
            Raw audio bytes (usually WAV or MP3 depending on the model).
        """
        client = self._load_tts()
        try:
            audio = client.text_to_speech(text, model=self.tts_model)
        except Exception:
            # Primary model failed (cold start, quota, unsupported task):
            # fall back to a well-known public TTS model.  If this also
            # fails, the original error is preserved as __context__.
            alt_client = InferenceClient(token=self.hf_token)
            audio = alt_client.text_to_speech(text, model="espnet/kan-bayashi_ljspeech_vits")
        # Some huggingface_hub versions return a file-like object
        # instead of bytes.
        if hasattr(audio, "read"):
            return audio.read()
        return audio

    def synthesize_to_file(self, text: str, output_path: str, voice: str = "af") -> str:
        """Synthesize `text` and write the audio bytes to `output_path`."""
        audio_bytes = self.synthesize(text, voice)
        with open(output_path, "wb") as f:
            f.write(audio_bytes)
        return output_path

    # ------------------------------------------------------------------
    # Gradio helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_mono_float32(audio_np: np.ndarray) -> np.ndarray:
        """Convert an audio array to mono float32 roughly in [-1.0, 1.0].

        Gradio microphone input is typically int16 PCM; integer samples
        must be scaled down by the dtype's max value because Whisper
        expects normalized float audio.  (A plain astype(float32) would
        leave values in the tens of thousands.)
        """
        if np.issubdtype(audio_np.dtype, np.integer):
            # Normalize BEFORE the channel mean: mean() on int arrays
            # promotes to float64 and would hide the integer dtype.
            audio_np = audio_np.astype(np.float32) / float(np.iinfo(audio_np.dtype).max)
        if audio_np.ndim > 1:
            # Average channels down to mono.
            audio_np = audio_np.mean(axis=1)
        if audio_np.dtype != np.float32:
            audio_np = audio_np.astype(np.float32)
        return audio_np

    def process_gradio_audio(self, audio_tuple) -> str:
        """Transcribe Gradio audio input.

        Args:
            audio_tuple: (sample_rate, numpy_array) as produced by a Gradio
                Audio component, or None when nothing was recorded.

        Returns:
            The transcribed text, or "" for empty input.
        """
        if audio_tuple is None:
            return ""
        sample_rate, audio_np = audio_tuple
        audio_np = self._to_mono_float32(audio_np)
        result = self.transcribe(audio_np, sample_rate)
        return result["text"]