File size: 14,455 Bytes
fe82c06
 
b672ef4
f4a0156
fe82c06
 
 
f4a0156
91549c5
 
fe82c06
 
f4a0156
 
3361f15
 
f4a0156
 
 
 
 
 
91549c5
f4a0156
fe82c06
f4a0156
 
 
7d46a3c
f4a0156
 
033af1b
fe82c06
 
b672ef4
033af1b
fe82c06
033af1b
fe82c06
033af1b
 
fe82c06
 
033af1b
fe82c06
 
f4a0156
fe82c06
 
033af1b
f4a0156
 
fe82c06
f4a0156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
033af1b
f4a0156
 
 
 
 
 
 
 
 
 
 
 
 
033af1b
f4a0156
 
 
 
 
 
 
 
 
 
 
 
033af1b
fe82c06
 
f4a0156
 
 
 
 
 
 
fe82c06
 
f4a0156
033af1b
fe82c06
 
033af1b
 
fe82c06
 
 
 
 
 
 
f4a0156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
033af1b
f4a0156
 
 
309ccf7
033af1b
f4a0156
 
 
 
 
 
 
 
309ccf7
f4a0156
 
033af1b
f4a0156
 
 
 
 
033af1b
 
3361f15
 
 
 
033af1b
3361f15
f4a0156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3361f15
f4a0156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3361f15
 
f4a0156
3361f15
f4a0156
 
 
ed6cecc
f4a0156
 
ed6cecc
 
3361f15
 
f4a0156
 
3361f15
 
d87796a
f4a0156
3361f15
 
 
f4a0156
 
d87796a
3c50bb0
 
 
b672ef4
309ccf7
 
3c50bb0
309ccf7
3c50bb0
309ccf7
3c50bb0
309ccf7
3c50bb0
 
f4a0156
 
309ccf7
f4a0156
033af1b
fe82c06
 
 
 
b672ef4
f4a0156
fe82c06
00c3484
 
 
4e57f03
00c3484
 
e261620
4e57f03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
#!/usr/bin/env python3
"""
Ringg Parrot STT V1 🦜 - Hugging Face Space (Frontend)
Real-time streaming transcription using Gradio's audio streaming.
"""

import io
import os
import tempfile
from pathlib import Path

import gradio as gr
import numpy as np
import requests
import soundfile as sf
from dotenv import load_dotenv

# librosa is optional: with it, resampling is high quality; without it,
# resample_audio falls back to plain linear interpolation.
try:
    import librosa
except ImportError:
    HAS_LIBROSA = False
    print("⚠️ librosa not installed. Install with: pip install librosa")
else:
    HAS_LIBROSA = True

# Pull variables from a local .env file (if present) into os.environ.
load_dotenv()

# Base URL of the backend STT API; override with the STT_API_ENDPOINT variable.
API_ENDPOINT = os.environ.get("STT_API_ENDPOINT", "http://localhost:7864")
# Sample rate (Hz) that streamed audio is converted to before upload.
TARGET_SAMPLE_RATE = 16_000

# Seconds of *new* streamed audio required before another transcription
# request is sent (throttles API calls during live recording).
MIN_AUDIO_LENGTH = 0.4


class RinggSTTClient:
    """Client for the Ringg Parrot STT HTTP API.

    All requests share one ``requests.Session`` (connection reuse plus a
    fixed User-Agent). The health and transcription methods catch exceptions
    and report failures through their return values instead of raising.
    """

    def __init__(self, api_endpoint: str):
        """Create a client for ``api_endpoint``; a trailing ``/`` is stripped."""
        self.api_endpoint = api_endpoint.rstrip("/")
        self.session = requests.Session()
        self.session.headers.update({"User-Agent": "RinggSTT-HF-Space/1.0"})

    def check_health(self) -> dict:
        """Probe ``GET /health`` (5 s timeout).

        Returns:
            ``{"status": "healthy", ...}`` on HTTP 200, otherwise
            ``{"status": "error", ...}``; ``"message"`` is display text.
        """
        try:
            response = self.session.get(f"{self.api_endpoint}/health", timeout=5)
        except Exception as e:
            return {"status": "error", "message": f"❌ Error: {str(e)}"}
        if response.status_code == 200:
            return {"status": "healthy", "message": "✅ API is online"}
        return {"status": "error", "message": f"❌ API returned status {response.status_code}"}

    def transcribe_audio_data(self, audio_data: np.ndarray, sample_rate: int, language: str = "hi") -> str:
        """Transcribe an in-memory signal via the multipart upload endpoint.

        The samples are encoded as WAV in memory and posted to
        ``/v1/audio/transcriptions`` with punctuation disabled (30 s timeout).

        Args:
            audio_data: Samples in a form ``soundfile.write`` accepts.
            sample_rate: Sample rate of ``audio_data`` in Hz.
            language: Language code forwarded to the backend (e.g. ``"hi"``).

        Returns:
            ``transcription_channel_0`` when the response carries it, else the
            ``transcription`` field; ``""`` on a non-200 status or any error.
        """
        try:
            # Encode in memory: unlike a delete=False temp file, nothing can be
            # left behind on disk if encoding or the upload fails.
            wav_buffer = io.BytesIO()
            sf.write(wav_buffer, audio_data, sample_rate, format="WAV")
            wav_buffer.seek(0)

            response = self.session.post(
                f"{self.api_endpoint}/v1/audio/transcriptions",
                files={"file": ("audio.wav", wav_buffer, "audio/wav")},
                data={"language": language, "punctuate": "false"},
                timeout=30,
            )

            # Debug: log the response for troubleshooting
            print(
                f"[transcribe_audio_data] status={response.status_code} "
                f"body={response.text[:500]}"
            )

            if response.status_code != 200:
                return ""
            result = response.json()
            # `or ""` keeps the return type a str even if the backend sends null.
            if "transcription_channel_0" in result:
                return result.get("transcription_channel_0") or ""
            return result.get("transcription") or ""
        except Exception as e:
            print(f"Transcription error: {e}")
            return ""

    def transcribe_file(self, audio_file_path: str, language: str = "hi") -> str:
        """Transcribe an audio file on disk via the multipart upload endpoint.

        Args:
            audio_file_path: Path of the file to upload (sent under its own name).
            language: Language code forwarded to the backend (e.g. ``"hi"``).

        Returns:
            The transcript (channel 0, plus a ``[Channel 2]`` section when a
            second channel is present), ``"No speech detected"`` when both
            channels are empty, or a display message prefixed with ``⏱`` (timeout)
            or ``❌`` (HTTP / other error).
        """
        try:
            with open(audio_file_path, "rb") as f:
                response = self.session.post(
                    f"{self.api_endpoint}/v1/audio/transcriptions",
                    files={"file": (Path(audio_file_path).name, f)},
                    data={"language": language, "punctuate": "false"},
                    timeout=120,
                )

            if response.status_code != 200:
                return f"❌ API Error: {response.status_code}"

            result = response.json()
            if "transcription_channel_0" not in result:
                return result.get("transcription", "No transcription received")

            transcripts = []
            if result.get("transcription_channel_0"):
                transcripts.append(result["transcription_channel_0"])
            if result.get("transcription_channel_1"):
                transcripts.append(f"\n[Channel 2]: {result['transcription_channel_1']}")
            return "".join(transcripts) if transcripts else "No speech detected"

        except requests.Timeout:
            # Long uploads can exceed the 120 s limit; give the user a hint.
            return "⏱ Request timed out. Try a shorter audio file."
        except Exception as e:
            return f"❌ Error: {str(e)}"


# Initialize API client at import time. The health probe result is only
# logged, so the UI still starts when the backend is unreachable.
print(f"🔗 Connecting to STT API: {API_ENDPOINT}")
stt_client = RinggSTTClient(API_ENDPOINT)
health_status = stt_client.check_health()
print(f"API Health: {health_status}")


def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample a 1-D signal from ``orig_sr`` Hz to ``target_sr`` Hz.

    Uses ``librosa.resample`` when librosa is installed. Otherwise it falls
    back to linear interpolation, which applies no anti-aliasing filter, so
    downsampling may alias; this is acceptable for a best-effort fallback.

    Returns:
        ``audio`` itself when the rates already match; otherwise a new float
        array (float64 for the interpolation fallback).
    """
    if orig_sr == target_sr:
        return audio

    samples = np.asarray(audio, dtype=np.float64)
    if samples.size == 0:
        # np.interp raises on an empty sample grid; nothing to resample anyway.
        return samples

    if HAS_LIBROSA:
        return librosa.resample(samples, orig_sr=orig_sr, target_sr=target_sr)

    # Linear-interpolation fallback. Integer arithmetic gives the exact output
    # length (float `duration * target_sr` can truncate one sample short), and
    # output sample k is read at input position k * orig_sr / target_sr so the
    # timing is not stretched.
    new_length = int(len(samples) * target_sr // orig_sr)
    positions = np.arange(new_length) * (orig_sr / target_sr)
    return np.interp(positions, np.arange(len(samples)), samples)


def transcribe_stream(audio, language, audio_buffer, last_transcription, samples_processed):
    """
    Process one streamed microphone chunk and refresh the live transcript.

    Simplified approach:
    - Accumulate ALL audio chunks in ``audio_buffer``
    - Once at least ``MIN_AUDIO_LENGTH`` seconds of new audio have arrived,
      resample the ENTIRE recording to ``TARGET_SAMPLE_RATE``, peak-normalise
      it and transcribe it
    - Display the complete transcription (backend handles everything)

    Args:
        audio: Gradio streaming payload, expected as ``(sample_rate, samples)``;
            ``None``, ints and other shapes are treated as "no audio yet".
        language: Dropdown label; ``"Hindi"`` maps to ``"hi"``, anything else to ``"en"``.
        audio_buffer: State list of mono chunks received so far (``None`` initially).
        last_transcription: State holding the latest non-empty transcript.
        samples_processed: State count of buffered samples already transcribed.

    Returns:
        ``(display_text, audio_buffer, last_transcription, samples_processed)``.
    """
    # Initialize states
    if audio_buffer is None:
        audio_buffer = []
    if last_transcription is None:
        last_transcription = ""
    if samples_processed is None:
        samples_processed = 0

    # Handle invalid audio input
    if audio is None or isinstance(audio, int):
        display = last_transcription if last_transcription else "🎀 Click microphone to start..."
        return display, audio_buffer, last_transcription, samples_processed

    # Gradio streaming returns (sample_rate, audio_data)
    if not isinstance(audio, tuple) or len(audio) != 2:
        display = last_transcription if last_transcription else "🎀 Listening..."
        return display, audio_buffer, last_transcription, samples_processed

    sample_rate, audio_data = audio

    if not isinstance(audio_data, np.ndarray) or len(audio_data) == 0:
        display = last_transcription if last_transcription else "🎀 Listening..."
        return display, audio_buffer, last_transcription, samples_processed

    # Convert stereo to mono if needed
    if audio_data.ndim > 1:
        audio_data = np.mean(audio_data, axis=1)

    # Append this chunk to buffer
    audio_buffer.append(audio_data.copy())

    total_samples = sum(len(arr) for arr in audio_buffer)
    total_duration = total_samples / sample_rate

    # Only transcribe if we have enough NEW audio (to avoid too frequent API calls)
    new_duration = (total_samples - samples_processed) / sample_rate
    if new_duration < MIN_AUDIO_LENGTH:
        display = last_transcription if last_transcription else f"🎀 Recording... ({total_duration:.1f}s)"
        return display, audio_buffer, last_transcription, samples_processed

    try:
        # Work in float64 from the start: np.abs on int16 PCM overflows at
        # -32768, which would corrupt the peak used for normalisation below.
        full_audio = np.concatenate(audio_buffer).astype(np.float64)

        if sample_rate != TARGET_SAMPLE_RATE:
            full_audio = resample_audio(full_audio, sample_rate, TARGET_SAMPLE_RATE)

        # Peak-normalise to 0.95 (skip all-zero audio to avoid dividing by 0)
        peak = np.max(np.abs(full_audio))
        if peak > 0:
            full_audio = full_audio / peak * 0.95

        lang_code = "hi" if language == "Hindi" else "en"

        # Transcribe the ENTIRE audio
        transcription = stt_client.transcribe_audio_data(
            full_audio.astype(np.float32),
            TARGET_SAMPLE_RATE,
            lang_code,
        )

        # Keep the previous transcript when the backend returns nothing
        # (empty string or None) rather than blanking the display.
        if transcription and transcription.strip():
            last_transcription = transcription

        # Mark all current samples as processed
        samples_processed = total_samples

        display = last_transcription if last_transcription else f"🎀 Recording... ({total_duration:.1f}s)"
        return display, audio_buffer, last_transcription, samples_processed

    except Exception as e:
        print(f"Processing error: {e}")
        display = last_transcription if last_transcription else "🎀 Listening..."
        return display, audio_buffer, last_transcription, samples_processed


def clear_transcription():
    """Reset the live-transcription widgets.

    Returns values for ``(text_output, audio_buffer, last_transcription,
    samples_processed)``: the idle placeholder, no buffered audio, an empty
    transcript and a zero sample counter.
    """
    idle_message = "🎀 Click microphone to start..."
    no_buffer, empty_text, zero_samples = None, "", 0
    return idle_message, no_buffer, empty_text, zero_samples


def transcribe_file(audio_file, language):
    """Transcribe an uploaded audio file for the file-upload tab.

    Args:
        audio_file: Path of the uploaded file, or ``None`` when nothing was uploaded.
        language: Dropdown label; ``"Hindi"`` maps to ``"hi"``, anything else to ``"en"``.

    Returns:
        Text for the output box: the stripped transcript or the client's
        message (e.g. ``❌``-prefixed errors) verbatim, or a hint when the
        result is empty.
    """
    if audio_file is None:
        return "⚠️ Please upload an audio file to transcribe."

    lang_code = "hi" if language == "Hindi" else "en"
    text = (stt_client.transcribe_file(audio_file, lang_code) or "").strip()

    # Transcripts and error messages are both shown as-is; only an empty
    # result needs the fallback hint.
    return text or "⚠️ No speech detected—try a clearer recording."


def create_interface():
    """Build the Gradio Blocks UI for the Space.

    Layout, top to bottom: header, real-time microphone transcription,
    file-upload transcription, benchmark table and acknowledgements.

    Returns:
        The ``gr.Blocks`` app; queueing and launching are left to the caller.
    """

    with gr.Blocks(
        theme=gr.themes.Base(
            font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
        ),
        # Let the app span the full page width instead of Gradio's default max width.
        css=".gradio-container {max-width: none !important;}",
    ) as demo:
        gr.HTML("""
            <div style="display: flex; align-items: center; gap: 10px;">
                <img style="width: 50px; height: 50px; background-color: white; border-radius: 10%;"
                     src="https://storage.googleapis.com/desivocal-prod/desi-vocal/ringg.svg" alt="Logo">
                <h1 style="margin: 0;">Ringg Parrot STT V1.0 🦜</h1>
            </div>
        """)

        # Real-time streaming section
        gr.Markdown("""
            ## 🎀 Real-time Transcription
            Click the microphone to start recording. Transcription updates as you speak.

            *The entire recording is transcribed each time, so text may refine as more context is added.*
        """)

        # Per-session streaming state: buffered chunks, latest transcript,
        # and how many buffered samples have already been transcribed.
        audio_buffer = gr.State(None)
        last_transcription = gr.State("")
        samples_processed = gr.State(0)

        with gr.Row():
            with gr.Column(scale=1):
                stream_language = gr.Dropdown(
                    choices=["Hindi", "English"],
                    value="Hindi",
                    label="Language",
                )
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    streaming=True,
                    label="🎀 Click to start recording",
                )
                clear_btn = gr.Button("🗑️ Clear & Reset", variant="secondary")

            with gr.Column(scale=2):
                text_output = gr.Textbox(
                    label="Transcription",
                    value="🎀 Click microphone to start...",
                    lines=10,
                    interactive=False,
                )

        # Wire up streaming: the state components are both inputs and outputs
        # so each chunk sees the values produced by the previous one.
        audio_input.stream(
            fn=transcribe_stream,
            inputs=[audio_input, stream_language, audio_buffer, last_transcription, samples_processed],
            outputs=[text_output, audio_buffer, last_transcription, samples_processed],
        )

        # Clear button resets the display and all streaming state
        clear_btn.click(
            fn=clear_transcription,
            inputs=[],
            outputs=[text_output, audio_buffer, last_transcription, samples_processed],
        )

        gr.Markdown("<br>")

        # File upload section
        gr.Markdown("""
            ## 📁 Upload an audio file for transcription
            Supports WAV, MP3, FLAC, M4A, and more.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file_language = gr.Dropdown(
                    choices=["Hindi", "English"],
                    value="Hindi",
                    label="Language",
                )
                file_input = gr.Audio(
                    type="filepath",
                    sources=["upload"],
                    label="Upload Audio",
                )
                transcribe_btn = gr.Button("Transcribe File", variant="primary", size="lg")

            with gr.Column(scale=2):
                file_output = gr.Textbox(
                    label="Transcription",
                    lines=8,
                    interactive=False,
                )

        transcribe_btn.click(
            fn=transcribe_file,
            inputs=[file_input, file_language],
            outputs=file_output,
        )

        gr.Markdown("""
            <br>

            ## 🎯 Performance Benchmarks
            **Ringg Parrot STT V1** Ranks **1st** Among Top Models.
        """)

        with gr.Row():
            gr.DataFrame(
                value=[
                    ["Parrot STT (Ringg AI)", "15.00%", "15.92%"],
                    ["IndicWav2Vec", "19.35%", "20.91%"],
                    ["VakyanSh Wav2Vec2", "22.73%", "24.78%"],
                ],
                headers=["Model", "Median WER ↓", "Mean WER ↓"],
                datatype=["str", "str", "str"],
                row_count=3,
                col_count=(3, "fixed"),
                interactive=False,
            )

        gr.Markdown("""
            ## 🙏 Acknowledgements
            - Built with [NVIDIA NeMo](https://github.com/NVIDIA/NeMo) models
        """)

    return demo


if __name__ == "__main__":
    # Script entry point: build the UI, enable the request queue, then serve.
    print("🌐 Launching Ringg Parrot STT V1 Gradio Interface...")
    print(f"Backend API: {API_ENDPOINT}")
    demo = create_interface()
    # Gradio queue: at most 2 concurrent runs per event, up to 20 waiting requests.
    demo.queue(default_concurrency_limit=2, max_size=20)
    launch_options = dict(
        share=False,
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,
        debug=True,
        show_api=False,
    )
    demo.launch(**launch_options)