WJ88 committed on
Commit
e6059ff
·
verified ·
1 Parent(s): 45979c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -170
app.py CHANGED
@@ -1,183 +1,97 @@
1
  import gradio as gr
2
- import nemo.collections.asr as nemo_asr
3
  import numpy as np
4
- from pydub import AudioSegment
5
- from pydub.silence import detect_silence
6
- import warnings
7
- import torch
8
- import logging
9
- import io
10
  import os
11
- import datetime
12
-
13
- warnings.filterwarnings("ignore")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- # Setup file-based logging for persistence
16
- LOG_FILE = "/tmp/app_logs.txt"
17
- logging.basicConfig(
18
- level=logging.INFO,
19
- format='%(asctime)s - %(levelname)s - %(message)s',
20
- handlers=[
21
- logging.FileHandler(LOG_FILE, mode='a'),
22
- logging.StreamHandler() # Also to console for HF logs
23
- ]
 
 
 
24
  )
25
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
26
 
27
- def append_log(message):
28
- """Append log message to file and return updated log content."""
29
- logger.info(message)
30
  try:
31
- with open(LOG_FILE, 'r') as f:
32
- logs = f.read()
33
- except FileNotFoundError:
34
- logs = ""
35
- return logs
36
-
37
- # Global model loader
38
- model = None
39
-
40
- def load_model():
41
- global model
42
- if model is None:
43
- logger.info("Loading Parakeet v3 model...")
44
- model = nemo_asr.models.ASRModel.from_pretrained(
45
- model_name="nvidia/parakeet-tdt-0.6b-v3",
46
- map_location="cpu"
47
- )
48
- model.eval()
49
- logger.info("Model loaded successfully.")
50
- return model
51
-
52
- class TranscriptionState:
53
- def __init__(self):
54
- self.buffer = None # AudioSegment
55
- self.text = ""
 
 
 
 
 
56
 
57
- def transcribe_segment(segment_array: np.ndarray):
58
- """Transcribe a normalized audio segment."""
59
- load_model()
60
- logger.info(f"Transcribing segment of length {len(segment_array)} samples.")
61
- with torch.no_grad(), warnings.catch_warnings():
62
- warnings.simplefilter("ignore")
63
- output = model.transcribe([segment_array])
64
- logger.info(f"Transcription complete: '{output[0][:50]}...'")
65
- return output[0]
66
-
67
- def process_live_audio(chunk_bytes, state: TranscriptionState):
68
- """Process live mic PCM bytes chunk with VAD and buffer management."""
69
- if chunk_bytes is None or len(chunk_bytes) == 0:
70
- logger.debug("Empty chunk received.")
71
- return state.text, state, append_log("Empty chunk skipped.")
72
-
73
- chunk_size = len(chunk_bytes)
74
- logger.debug(f"Received chunk of {chunk_size} bytes.")
75
-
76
- # Create AudioSegment from raw PCM bytes (16kHz mono int16)
77
- try:
78
- new_segment = AudioSegment(
79
- data=chunk_bytes,
80
- frame_rate=16000,
81
- sample_width=2,
82
- channels=1
83
- )
84
  except Exception as e:
85
- logger.error(f"Chunk creation error: {e}")
86
- return state.text, state, append_log(f"Chunk error: {e}")
87
-
88
- # Append to buffer
89
- if state.buffer is None:
90
- state.buffer = new_segment
91
- logger.debug("Initialized new buffer.")
92
- else:
93
- state.buffer += new_segment
94
-
95
- buffer_dur = state.buffer.duration_seconds
96
- logger.debug(f"Buffer duration: {buffer_dur:.1f}s")
97
-
98
- # Trim buffer to prevent accumulation (keep last 60s)
99
- if buffer_dur > 60:
100
- logger.info("Buffer exceeded 60s; trimming and re-transcribing.")
101
- full_array = np.array(state.buffer.get_array_of_samples(), dtype=np.float32) / 32768.0
102
- state.text = transcribe_segment(full_array)
103
- state.buffer = state.buffer[-30000:]
104
- return state.text, state, append_log("Buffer trimmed at 60s.")
105
-
106
- # VAD: Detect pauses in current buffer
107
- silent_windows = detect_silence(
108
- state.buffer,
109
- min_silence_len=500, # 0.5s pause
110
- silence_thresh=-40 # dB threshold
111
- )
112
-
113
- if len(silent_windows) > 0:
114
- last_silence_end = silent_windows[-1][1]
115
- if last_silence_end < len(state.buffer):
116
- logger.info(f"VAD detected pause at {last_silence_end}ms; transcribing up to pause.")
117
- segment = state.buffer[:last_silence_end]
118
- segment_array = np.array(segment.get_array_of_samples(), dtype=np.float32) / 32768.0
119
- partial_text = transcribe_segment(segment_array)
120
- state.text = partial_text
121
- state.buffer = state.buffer[last_silence_end:]
122
- return state.text, state, append_log(f"VAD update: Pause detected, transcribed '{partial_text[:50]}...'")
123
-
124
- return state.text, state, append_log(f"Chunk appended; buffer at {buffer_dur:.1f}s, awaiting pause.")
125
-
126
- def clear_session(state: TranscriptionState):
127
- """Reset session."""
128
- state.buffer = None
129
- state.text = ""
130
- logger.info("Session cleared by user.")
131
- return "", state, append_log("Session cleared.")
132
-
133
- # Gradio UI (mic-only)
134
- with gr.Blocks(title="Parakeet v3 Real-Time Mic Transcription") as demo:
135
- gr.Markdown(
136
- """
137
- # NVIDIA Parakeet-TDT 0.6B v3 Real-Time Transcription
138
- Speak continuously into the microphone—transcription updates live on natural pauses (0.5s+). Supports 25 European languages automatically. Optimized for CPU.
139
- """
140
- )
141
 
142
- state = gr.State(TranscriptionState())
143
- audio_input = gr.Audio(
144
- sources=["microphone"],
145
- type="bytes",
146
- streaming=True,
147
- label="Speak now—updates on pauses",
148
- waveform_options={"show_recording_waveform": True}
149
- )
150
- output_text = gr.Textbox(
151
- label="Live Transcription",
152
- lines=10,
153
- interactive=False
154
- )
155
- log_text = gr.Textbox(
156
- label="Debug Logs (Persistent)",
157
- lines=15,
158
- interactive=False,
159
- show_copy_button=True
160
- )
161
- clear_btn = gr.Button("Clear Session", variant="secondary")
162
 
163
- # Stream updates on each chunk
164
- audio_input.change(
165
- process_live_audio,
166
- inputs=[audio_input, state],
167
- outputs=[output_text, state, log_text],
168
- show_progress="minimal"
169
- )
170
- clear_btn.click(
171
- clear_session,
172
- inputs=state,
173
- outputs=[output_text, state, log_text]
174
- )
175
 
176
- gr.Markdown(
177
- """
178
- **Tips:** Speak clearly with brief pauses for instant updates. Long monologues auto-update every 60s. Logs show real-time debug info.
179
- """
180
- )
181
 
182
- if __name__ == "__main__":
183
- demo.launch(share=False, debug=True)
 
1
  import gradio as gr
 
2
  import numpy as np
3
+ import sherpa_onnx
4
+ import time
 
 
 
 
5
  import os
6
+ import urllib.request
7
+ import tarfile
8
+
9
# Location of the int8 Parakeet-TDT v3 model, downloaded on first run.
model_dir = "sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8"
MODEL_URL = "https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-parakeet-tdt-0.6b-v3-int8.tar.bz2"

def ensure_model(target_dir=model_dir, url=MODEL_URL):
    """Download and extract the model archive if *target_dir* is missing.

    The temporary tarball is removed in a ``finally`` block so a failed
    download or extraction does not leave a stale ``model.tar.bz2`` that
    would be re-extracted (and crash) on the next start.

    Returns the model directory path.
    """
    if os.path.exists(target_dir):
        return target_dir
    archive = "model.tar.bz2"
    try:
        urllib.request.urlretrieve(url, archive)
        with tarfile.open(archive) as tar:
            # NOTE(review): archive comes from a pinned upstream release;
            # on Python >= 3.12 consider extractall(filter="data") anyway.
            tar.extractall()
    finally:
        if os.path.exists(archive):
            os.remove(archive)
    return target_dir

ensure_model()
17
+
18
# Build the streaming recognizer.
#
# The sherpa-onnx Python bindings expose `OnlineRecognizer.from_transducer`
# as the supported constructor for transducer models; the previous
# hand-assembled OnlineRecognizerConfig used names that do not exist in the
# bindings (`FeatureConfig`, `tokens=` on the recognizer config), so the
# app crashed at import.  `from_transducer` also wires up endpoint
# detection, replacing the separate EndpointConfig object.
#
# NOTE(review): parakeet-tdt-0.6b-v3 is published by k2-fsa as an *offline*
# model — verify the archive actually contains a streaming transducer; if
# not, this app needs OfflineRecognizer with manual chunking instead.
recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
    tokens=os.path.join(model_dir, "tokens.txt"),
    encoder=os.path.join(model_dir, "encoder.int8.onnx"),
    decoder=os.path.join(model_dir, "decoder.int8.onnx"),
    joiner=os.path.join(model_dir, "joiner.int8.onnx"),
    num_threads=2,                     # match HF free-tier CPU cores
    sample_rate=16000,
    feature_dim=80,
    provider="cpu",
    model_type="nemo_transducer",      # Parakeet-TDT is a NeMo transducer
    enable_endpoint_detection=True,
    rule1_min_trailing_silence=1.0,    # endpoint after 1.0 s of pure silence
    rule2_min_trailing_silence=0.5,    # endpoint 0.5 s after trailing speech
    rule3_min_utterance_length=30.0,   # force an endpoint at 30 s utterances
)
39
+
40
def transcribe(state, audio_chunk):
    """Consume one streamed microphone chunk; return updated UI values.

    Args:
        state: per-session dict holding the recognizer stream, the text
            committed at endpoints (``transcript``), the in-progress
            partial, the debug log, and the timestamp of the last partial
            update.  ``None`` on the first chunk of a session.
        audio_chunk: ``(sample_rate, samples)`` tuple from ``gr.Audio``
            with ``type="numpy"``; samples are int16 (possibly stereo).

    Returns:
        ``(state, display_text, log_text)`` matching the ``outputs``
        wiring of ``audio.stream``.
    """
    if state is None:
        state = {
            "stream": recognizer.create_stream(),
            "transcript": "",        # text committed at endpoints
            "current_partial": "",   # text of the in-progress utterance
            "log": "",
            "last_time": time.time(),
        }

    try:
        sr, y = audio_chunk
        if y.ndim > 1:
            y = np.mean(y, axis=1)  # downmix stereo to mono

        # Convert to float32 in [-1, 1].  Dividing by the int16 full-scale
        # value preserves relative loudness across chunks; the previous
        # per-chunk max-normalization amplified quiet/noise-only chunks to
        # full scale, which degrades recognition.
        if np.issubdtype(y.dtype, np.integer):
            y = y.astype(np.float32) / 32768.0
        else:
            y = y.astype(np.float32)

        if np.max(np.abs(y)) == 0:
            # All-zero chunk: nothing to decode, just report it.
            state["log"] += "Weak signal detected.\n"
            return state, state["transcript"] + state["current_partial"], state["log"]

        state["stream"].accept_waveform(sr, y)

        # Decode everything the feature pipeline has buffered so far.
        while recognizer.is_ready(state["stream"]):
            recognizer.decode_stream(state["stream"])

        # get_result() returns the partial text as a plain string in the
        # sherpa-onnx Python API; the old `result.text` attribute access
        # raised AttributeError on every chunk, which the except below
        # silently turned into a log line.
        current_text = recognizer.get_result(state["stream"]).strip()

        if current_text != state["current_partial"]:
            state["current_partial"] = current_text
            latency = time.time() - state["last_time"]
            state["log"] += f"Partial update (latency: {latency:.2f}s): {current_text}\n"
            state["last_time"] = time.time()

        # On an endpoint (natural pause), commit the utterance and reset
        # the stream for the next one.
        if recognizer.is_endpoint(state["stream"]):
            if current_text:
                state["transcript"] += current_text + " "
                state["log"] += f"Endpoint detected, committed: {current_text}\n"
            recognizer.reset(state["stream"])
            state["current_partial"] = ""

    except Exception as e:
        # Surface errors in the debug textbox instead of killing the stream.
        state["log"] += f"Error: {str(e)}\n"

    return state, state["transcript"] + state["current_partial"], state["log"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
# Gradio UI: streaming microphone in, live transcript + debug log out.
with gr.Blocks() as demo:
    gr.Markdown("# Real-Time Multilingual Microphone Transcription")
    with gr.Row():
        # `sources` (a list) is the Gradio 4.x parameter; the scalar
        # `source=` keyword was removed after 3.x and raises TypeError.
        audio = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak here")
        transcript = gr.Textbox(label="Transcription", interactive=False)
        logs = gr.Textbox(label="Debug Logs", interactive=False, lines=5)
    # Per-session state so concurrent users don't share one stream.
    state = gr.State()

    # Fire transcribe for every streamed mic chunk.
    audio.stream(transcribe, [state, audio], [state, transcript, logs])

# Guard launch so importing this module (e.g. for tests) doesn't start a server.
if __name__ == "__main__":
    demo.launch()