File size: 9,223 Bytes
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97d41ab
e568430
 
 
 
 
97d41ab
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ca2eb
e568430
 
 
71ca2eb
e568430
 
 
fa36a89
e568430
 
 
fa36a89
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ca2eb
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ca2eb
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71ca2eb
e568430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""Speech-to-Text service using Gradio Client API."""

import asyncio
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import Any

import numpy as np
import structlog
from gradio_client import Client, handle_file

from src.utils.config import settings
from src.utils.exceptions import ConfigurationError

logger = structlog.get_logger(__name__)


class STTService:
    """STT service using nvidia/canary-1b-v2 Gradio Space.

    The Gradio client is created lazily on first use and cached on the
    instance. Blocking Gradio calls are dispatched to the running event loop's
    default executor.
    """

    # Space URL used when neither the constructor argument nor settings supply one.
    DEFAULT_API_URL = "https://nvidia-canary-1b-v2.hf.space"

    def __init__(self, api_url: str | None = None, hf_token: str | None = None) -> None:
        """Initialize STT service.

        Args:
            api_url: Gradio Space URL (default: settings.stt_api_url, then
                DEFAULT_API_URL)
            hf_token: HuggingFace token for authenticated Spaces (default: None)

        Raises:
            ConfigurationError: If no API URL can be resolved (only possible
                when DEFAULT_API_URL has been overridden with an empty value)
        """
        self.api_url = api_url or settings.stt_api_url or self.DEFAULT_API_URL
        if not self.api_url:
            raise ConfigurationError("STT API URL not configured")
        self.hf_token = hf_token
        # Created on first use by _get_client().
        self.client: Client | None = None

    async def _get_client(self, hf_token: str | None = None) -> Client:
        """Get or create Gradio Client (lazy initialization).

        The cached client is discarded and rebuilt when a token is supplied
        that differs from the one stored on the instance.

        Args:
            hf_token: HuggingFace token for authenticated Spaces (overrides instance token)

        Returns:
            Gradio Client instance
        """
        # Use provided token or instance token
        token = hf_token or self.hf_token

        # If client exists but token changed, recreate it
        if self.client is not None and token != self.hf_token:
            self.client = None

        if self.client is None:
            # Gradio Client uses 'token' parameter, not 'hf_token'. The keyword
            # is only passed when a token is set, so anonymous access keeps
            # the client's own default.
            client_kwargs: dict[str, Any] = {"token": token} if token else {}
            loop = asyncio.get_running_loop()
            self.client = await loop.run_in_executor(
                None,
                lambda: Client(self.api_url, **client_kwargs),
            )
            # Update instance token for future use
            self.hf_token = token
        return self.client

    async def transcribe_file(
        self,
        audio_path: str,
        source_lang: str | None = None,
        target_lang: str | None = None,
        hf_token: str | None = None,
    ) -> str:
        """Transcribe audio file using Gradio API.

        Args:
            audio_path: Path to audio file
            source_lang: Source language (default: settings.stt_source_lang)
            target_lang: Target language (default: settings.stt_target_lang)
            hf_token: HuggingFace token for authenticated Spaces (overrides
                the instance token)

        Returns:
            Transcribed text string (empty when no text could be extracted)

        Raises:
            ConfigurationError: If the API call or result extraction fails.
                Errors raised while creating the Gradio client are not wrapped
                and propagate unchanged.
        """
        client = await self._get_client(hf_token=hf_token)
        source_lang = source_lang or settings.stt_source_lang
        target_lang = target_lang or settings.stt_target_lang

        logger.info(
            "transcribing_audio_file",
            audio_path=audio_path,
            source_lang=source_lang,
            target_lang=target_lang,
        )

        try:
            # Call /transcribe_file API endpoint
            # API returns: (dataframe, csv_path, srt_path)
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: client.predict(
                    audio_path=handle_file(audio_path),
                    source_lang=source_lang,
                    target_lang=target_lang,
                    api_name="/transcribe_file",
                ),
            )

            # Extract transcription from result
            transcribed_text = self._extract_transcription(result)

            logger.info(
                "audio_transcription_complete",
                text_length=len(transcribed_text),
            )

            return transcribed_text

        except Exception as e:
            logger.error("audio_transcription_failed", error=str(e), error_type=type(e).__name__)
            raise ConfigurationError(f"Audio transcription failed: {e}") from e

    async def transcribe_audio(
        self,
        audio_data: tuple[int, np.ndarray[Any, Any]],  # type: ignore[type-arg]
        hf_token: str | None = None,
    ) -> str:
        """Transcribe audio numpy array to text.

        The array is written to a temporary WAV file, transcribed through
        transcribe_file(), and the temporary file is removed afterwards
        whether or not transcription succeeded.

        Args:
            audio_data: Tuple of (sample_rate, audio_array)
            hf_token: HuggingFace token for authenticated Spaces (overrides
                the instance token)

        Returns:
            Transcribed text string

        Raises:
            ConfigurationError: If the audio cannot be saved or transcription fails
        """
        sample_rate, audio_array = audio_data

        logger.info(
            "transcribing_audio_array",
            sample_rate=sample_rate,
            audio_shape=audio_array.shape,
        )

        # Save audio to temp file
        temp_path = self._save_audio_temp(audio_data)

        try:
            # Transcribe the temp file
            return await self.transcribe_file(temp_path, hf_token=hf_token)
        finally:
            # Clean up temp file (best effort; never masks the real outcome)
            self._remove_temp_file(temp_path)

    def _extract_transcription(self, api_result: tuple[Any, ...]) -> str:
        """Extract transcription text from API result.

        The dataframe payload is tried first, then the CSV file. In both, the
        column named "text" is used when present, otherwise the first column.

        Args:
            api_result: Tuple from Gradio API (dataframe, csv_path, srt_path)

        Returns:
            Extracted transcription text, or an empty string when nothing
            could be extracted (a warning is logged in that case)
        """
        # API returns: (dataframe, csv_path, srt_path)
        if isinstance(api_result, (tuple, list)) and len(api_result) >= 1:
            # Try to extract from dataframe first
            text = self._text_from_dataframe(api_result[0])
            if text is not None:
                return text

            # Fallback: try to read CSV file if available
            if len(api_result) >= 2 and api_result[1]:
                text = self._text_from_csv(api_result[1])
                if text is not None:
                    return text

        # Last resort: return empty string
        logger.warning("could_not_extract_transcription", result_type=type(api_result).__name__)
        return ""

    @staticmethod
    def _find_text_column(column_names: Any) -> int:
        """Return the index of the column holding the text, defaulting to 0.

        An exact "text" name wins, then a case-insensitive match; anything
        that is not a list/tuple of names yields 0.
        """
        if not isinstance(column_names, (list, tuple)):
            return 0
        names = [str(name).strip() for name in column_names]
        if "text" in names:
            return names.index("text")
        lowered = [name.lower() for name in names]
        if "text" in lowered:
            return lowered.index("text")
        # First column is usually the text
        return 0

    def _text_from_dataframe(self, dataframe: Any) -> str | None:
        """Join the text cells of a dataframe payload, or return None if unusable.

        Expects a dict with a "data" list of rows; a "headers" list, when
        present, is used to locate the text column.
        """
        if not isinstance(dataframe, dict) or "data" not in dataframe:
            return None
        rows = dataframe.get("data")
        if not rows or not isinstance(rows, (list, tuple)):
            return None

        column = self._find_text_column(dataframe.get("headers"))
        # Skip malformed rows and empty (None) cells so they do not end up in
        # the transcript as the literal word "None".
        text_segments = [
            str(row[column])
            for row in rows
            if isinstance(row, (list, tuple)) and len(row) > column and row[column] is not None
        ]
        return " ".join(text_segments) if text_segments else None

    def _text_from_csv(self, csv_path: Any) -> str | None:
        """Join the text column of the CSV at csv_path, or return None on failure.

        Read or import errors are logged as warnings and reported as None so
        the caller can fall through to its last resort.
        """
        try:
            import pandas as pd

            df = pd.read_csv(csv_path)
            if len(df.columns) == 0:
                return None
            column = self._find_text_column(list(df.columns))
            # dropna() keeps empty cells from being rendered as "nan".
            return " ".join(df.iloc[:, column].dropna().astype(str).tolist())
        except Exception as e:
            logger.warning("failed_to_read_csv", csv_path=csv_path, error=str(e))
            return None

    @staticmethod
    def _remove_temp_file(path: str) -> None:
        """Delete a temporary file, logging instead of raising on failure."""
        try:
            Path(path).unlink(missing_ok=True)
        except Exception as e:
            logger.warning("failed_to_cleanup_temp_file", path=path, error=str(e))

    def _save_audio_temp(
        self,
        audio_data: tuple[int, np.ndarray[Any, Any]],  # type: ignore[type-arg]
    ) -> str:
        """Save audio numpy array to temporary WAV file.

        The audio is converted to float32, multi-dimensional input is averaged
        over axis 1, and samples are scaled by their peak when any value lies
        outside [-1, 1]. The caller owns the returned file and must delete it.

        Args:
            audio_data: Tuple of (sample_rate, audio_array)

        Returns:
            Path to temporary WAV file

        Raises:
            ConfigurationError: If soundfile is not installed or the audio
                cannot be converted or written
        """
        sample_rate, audio_array = audio_data

        # Import before creating the temp file so a missing dependency cannot
        # leave a stray file behind.
        try:
            import soundfile as sf
        except ImportError:
            raise ConfigurationError(
                "soundfile not installed. Install with: uv add soundfile"
            ) from None

        # Create temp file; delete=False because it is reopened by path below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name

        # Save audio using soundfile
        try:
            if audio_array.size == 0:
                raise ValueError("audio array is empty")

            # Ensure audio is float32 and mono
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Handle stereo -> mono conversion
            if audio_array.ndim > 1:
                audio_array = np.mean(audio_array, axis=1)

            # Normalize to [-1, 1] range
            peak = np.max(np.abs(audio_array))
            if peak > 1.0:
                audio_array = audio_array / peak

            sf.write(temp_path, audio_array, sample_rate)
        except Exception as e:
            # The file was created with delete=False, so remove it here or it
            # would be leaked on every failed save.
            self._remove_temp_file(temp_path)
            logger.error("failed_to_save_audio_temp", error=str(e))
            raise ConfigurationError(f"Failed to save audio to temp file: {e}") from e

        logger.debug("saved_audio_temp", path=temp_path, sample_rate=sample_rate)

        return temp_path


@lru_cache(maxsize=1)
def get_stt_service() -> STTService:
    """Return the shared STT service, building it on the first call.

    The zero-argument call is memoized by ``lru_cache``, so later calls
    return the same instance.

    Returns:
        STTService instance
    """
    service = STTService()
    return service