# (Hugging Face Spaces status banner removed — not part of the application code)
| # app.py | |
| import os, gc, warnings, logging | |
| import torch, numpy as np, librosa, gradio as gr | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline | |
| from huggingface_hub import login | |
# -------------------------------
# HF Token Login (for private repos)
# -------------------------------
# Authenticate with the Hugging Face Hub only when a token is provided via the
# environment; required to download private model repos.
if "HF_TOKEN" in os.environ:
    login(token=os.environ["HF_TOKEN"])
# -------------------------------
# Config & Device
# -------------------------------
warnings.filterwarnings("ignore")  # keep the Space logs readable
# NOTE(review): this logger never gets a handler and the callbacks below log
# through the global `log_debug` string buffer instead — confirm it is needed.
logger = logging.getLogger("whisper_streaming")
logger.setLevel(logging.DEBUG)
# Prefer GPU + fp16 when available; otherwise CPU + fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"Using device: {device}, dtype={torch_dtype}")
# -------------------------------
# Model Loading
# -------------------------------
# UI display name -> Hugging Face model repo id for the selectable checkpoints.
MODEL_OPTIONS = {
    "Fine-tuned Cantonese": "thomaskywong0131/whisper-large-v3-cantonese",
    "OpenAI Large-v3": "openai/whisper-large-v3",
    "OpenAI Large-v3-Turbo": "openai/whisper-large-v3-turbo",
}
def load_model(model_choice="Fine-tuned Cantonese"):
    """Load a Whisper checkpoint and wrap it in a transformers ASR pipeline.

    Args:
        model_choice: Key into MODEL_OPTIONS selecting the checkpoint.

    Returns:
        Tuple of (automatic-speech-recognition pipeline, WhisperProcessor).
    """
    model_name = MODEL_OPTIONS[model_choice]
    print(f"Loading model: {model_name}")
    # Free GPU memory from any previously loaded model before loading a new one.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()
    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(
        model_name,
        # NOTE(review): `dtype=` is the newer transformers spelling; older
        # releases expect `torch_dtype=` — confirm against the pinned version.
        dtype=torch_dtype,
        device_map="auto" if device == "cuda" else None,  # GPU path needs `accelerate`
        use_safetensors=True,
    )
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        dtype=torch_dtype,
        generate_kwargs={"language": "yue"}  # force Cantonese ("yue") decoding
    )
    print(f"✅ Successfully loaded: {model_choice}")
    return pipe, processor
| pipe, processor = load_model("Fine-tuned Cantonese") | |
# -------------------------------
# HypothesisBuffer
# -------------------------------
class HypothesisBuffer:
    """Ordered store of (start, end, text) hypothesis segments from the ASR."""

    def __init__(self):
        # Each entry is (start_time, end_time, text); timestamps may be None.
        self.entries = []

    def insert(self, new, offset=0):
        """Append segments, shifting every non-None timestamp by `offset`."""
        shifted = [
            (None if beg is None else beg + offset,
             None if end is None else end + offset,
             text)
            for beg, end, text in new
        ]
        self.entries.extend(shifted)

    def reset(self):
        """Discard all accumulated segments."""
        self.entries = []

    def get_text(self):
        """Concatenate the text of every segment (no separator)."""
        return "".join(text for _, _, text in self.entries)

    def get_entries(self):
        return self.entries

    def complete(self):
        return self.entries

    def flush(self):
        return self.entries
# -------------------------------
# OnlineASRProcessor
# -------------------------------
class OnlineASRProcessor:
    """Accumulates raw audio and transcribes it in one-shot batches via `pipe`."""

    def __init__(self, pipe, processor, sample_rate=16000):
        self.pipe = pipe
        self.processor = processor
        self.sample_rate = sample_rate
        self.audio_accum = np.array([], dtype=np.float32)
        self.transcript_buffer = HypothesisBuffer()

    def init(self):
        """Reset both the pending-audio buffer and the transcript buffer."""
        self.audio_accum = np.array([], dtype=np.float32)
        self.transcript_buffer.reset()

    def insert_audio_chunk(self, audio: np.ndarray):
        """Append a mono float32 chunk to the pending-audio buffer."""
        self.audio_accum = np.append(self.audio_accum, audio)

    def process_iter(self):
        """Transcribe pending audio once at least one second has accumulated.

        Returns (None, None, text); text is "" when nothing was produced, in
        which case the audio is kept for the next call.
        """
        if len(self.audio_accum) < self.sample_rate:
            return None, None, ""
        try:
            txt = self.pipe(self.audio_accum, chunk_length_s=10)["text"].strip()
        except Exception as e:
            txt = f"[ASR error: {e}]"
        if not txt:
            # Keep accumulating; retry with more audio on the next iteration.
            return None, None, ""
        self.transcript_buffer.insert([(None, None, txt)])
        self.audio_accum = np.array([], dtype=np.float32)
        return None, None, txt

    def finish(self):
        """Transcribe whatever audio remains (end-of-stream flush)."""
        if len(self.audio_accum) == 0:
            return None, None, ""
        try:
            txt = self.pipe(self.audio_accum, chunk_length_s=30)["text"].strip()
        except Exception as e:
            txt = f"[ASR error: {e}]"
        if not txt:
            return None, None, ""
        self.transcript_buffer.insert([(None, None, txt)])
        self.audio_accum = np.array([], dtype=np.float32)
        return None, None, txt
# -------------------------------
# VACOnlineASRProcessor (Silero VAD)
# -------------------------------
class VACOnlineASRProcessor:
    """Voice-activity-controlled ASR: flushes speech to the recognizer on silence.

    Incoming audio is scored frame-by-frame with the Silero VAD model; voiced
    frames accumulate, and once `silence_sec` seconds of consecutive silence is
    observed the accumulated speech is transcribed through OnlineASRProcessor
    and the result queued for process_iter().
    """

    def __init__(self, pipe, processor, silence_sec=0.8, speech_threshold=0.5):
        self.online = OnlineASRProcessor(pipe, processor)
        # Loads Silero VAD via torch.hub (downloads on first run — needs network).
        self.model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False
        )
        self.sample_rate = 16000
        self.frame_size = 512  # Silero VAD consumes 512-sample frames at 16 kHz
        self.silence_sec = silence_sec
        self.speech_threshold = speech_threshold
        self.reset()

    def reset(self):
        """Clear all buffers and restart the VAD state machine."""
        self.online.init()
        self.buffer = np.array([], dtype=np.float32)       # raw samples awaiting framing
        self.audio_accum = np.array([], dtype=np.float32)  # voiced samples awaiting ASR
        self.silence_samples = 0                           # consecutive silent samples seen
        self.flush_queue = []                              # finished (beg, end, txt) results

    def insert_audio_chunk(self, audio: np.ndarray):
        """Feed raw audio; run VAD per frame and flush speech after silence."""
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32)
        # Values outside [-1, 1] indicate int16 PCM — rescale to float range.
        if audio.max() > 1.0 or audio.min() < -1.0:
            audio /= 32768.0
        self.buffer = np.append(self.buffer, audio)
        while len(self.buffer) >= self.frame_size:
            frame = self.buffer[:self.frame_size]
            self.buffer = self.buffer[self.frame_size:]
            tensor = torch.from_numpy(frame).unsqueeze(0)
            with torch.no_grad():
                speech_prob = self.model(tensor, self.sample_rate).item()
            log_debug(f"[VAD] prob={speech_prob:.2f}, silence={self.silence_samples}, accum={len(self.audio_accum)}")
            if speech_prob > self.speech_threshold:
                # Voiced frame: keep it and restart the silence counter.
                self.audio_accum = np.append(self.audio_accum, frame)
                self.silence_samples = 0
            else:
                self.silence_samples += self.frame_size
                # Enough consecutive silence: transcribe the buffered speech.
                if self.silence_samples >= self.sample_rate * self.silence_sec:
                    if len(self.audio_accum) > 0:
                        self.online.insert_audio_chunk(self.audio_accum)
                        beg, end, txt = self.online.finish()
                        if txt:
                            self.flush_queue.append((beg, end, txt))
                            log_debug(f"[FLUSH] Added to queue: {txt}")
                    self.audio_accum = np.array([], dtype=np.float32)
                    self.silence_samples = 0

    def process_iter(self):
        """Pop one queued transcript, or (None, None, "") if none is ready."""
        if self.flush_queue:
            return self.flush_queue.pop(0)
        return None, None, ""

    def finish(self):
        """Flush any audio still held by the underlying processor."""
        beg, end, txt = self.online.finish()
        if txt:
            return beg, end, txt
        return None, None, ""
# -------------------------------
# Gradio Callbacks
# -------------------------------
# Module-level UI state shared by the callbacks below.
stream_text = ""   # accumulated transcript shown in the Transcript box
debug_text = ""    # accumulated log lines shown in the Debug window
use_vac = False    # whether the VAD-gated pipeline is active
vac_online = None  # VACOnlineASRProcessor, created lazily when VAC is enabled
online = OnlineASRProcessor(pipe, processor)  # default (non-VAC) processor
silence_sec_value = 0.8        # last silence-threshold slider value (seconds)
speech_threshold_value = 0.5   # last speech-probability slider value
def log_debug(msg):
    """Append one line to the global debug log shown in the UI."""
    global debug_text
    debug_text = debug_text + msg + "\n"
def start_transcription(vac_mode, silence_sec, speech_threshold):
    """Reset streaming state and (re)build the requested processor.

    Returns the status string plus Gradio updates for the Start/Stop buttons
    and the current debug log.
    """
    global stream_text, debug_text, use_vac, vac_online, online
    global silence_sec_value, speech_threshold_value

    stream_text = ""
    debug_text = ""
    use_vac = vac_mode
    silence_sec_value = silence_sec
    speech_threshold_value = speech_threshold

    if not use_vac:
        online.init()
        log_debug("[START] VAC mode disabled (basic streaming)")
    else:
        # Rebuild the VAC processor so the latest slider settings take effect.
        vac_online = VACOnlineASRProcessor(
            pipe, processor,
            silence_sec=silence_sec_value,
            speech_threshold=speech_threshold_value
        )
        vac_online.reset()
        log_debug("[START] VAC mode enabled")

    log_debug(f"[SETTINGS] silence_sec={silence_sec_value:.2f}, speech_threshold={speech_threshold_value:.2f}")
    return "🔴 Streaming started", gr.update(interactive=False), gr.update(interactive=True), debug_text
def stop_transcription():
    """Flip the Start/Stop buttons and surface the final transcript and log."""
    start_update = gr.update(interactive=True)
    stop_update = gr.update(interactive=False)
    return "⏹️ Stopped", start_update, stop_update, stream_text, debug_text
def process_stream(audio):
    """Handle one streamed microphone chunk: normalize, resample, feed the ASR.

    Args:
        audio: Gradio streaming payload — either a (sample_rate, np.ndarray)
            tuple or a raw array (assumed already mono float at 16 kHz).

    Returns:
        (transcript_so_far, debug_log) for the two output textboxes.
    """
    global stream_text, debug_text, use_vac, vac_online, online
    if audio is None:
        return stream_text, debug_text
    if isinstance(audio, tuple):
        sr, arr = audio
        arr = np.array(arr)
        if arr.dtype != np.float32:
            arr = arr.astype(np.float32)
        # Fix: empty chunks can arrive at stream start/stop and arr.max() on an
        # empty array raises ValueError — bail out before normalizing.
        if arr.size == 0:
            return stream_text, debug_text
        # Values outside [-1, 1] indicate int16 PCM — rescale to float range.
        if arr.max() > 1.0 or arr.min() < -1.0:
            arr /= 32768.0
        if sr != 16000:
            arr = librosa.resample(arr, orig_sr=sr, target_sr=16000)
    else:
        # NOTE(review): assumes non-tuple input is already mono float at 16 kHz.
        arr = np.array(audio, dtype=np.float32)
        if arr.size == 0:
            return stream_text, debug_text
    if use_vac:
        vac_online.insert_audio_chunk(arr)
        beg, end, txt = vac_online.process_iter()
        log_debug(f"[VAC] Insert {len(arr)} samples | Output: {txt}")
    else:
        online.insert_audio_chunk(arr)
        beg, end, txt = online.process_iter()
        log_debug(f"[Online] Insert {len(arr)} samples | Output: {txt}")
    if txt:
        stream_text += txt + "\n"
        log_debug(f"[Flush] {beg}-{end} | '{txt}'")
    return stream_text, debug_text
def clear_text():
    """Empty both the transcript and the debug log and return the cleared pair."""
    global stream_text, debug_text
    stream_text, debug_text = "", ""
    return stream_text, debug_text
# -------------------------------
# Gradio UI
# -------------------------------
with gr.Blocks(title="Cantonese Streaming (VAC)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎤 Cantonese Streaming Transcription with VAC + Debug Logs")
    gr.Markdown("✅ 支援 VAC,並可在下方調整靜音閾值與語音閾值")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: VAC toggle, VAD tuning sliders, transport buttons.
            vac_mode = gr.Checkbox(label="啟用 VAC 模式", value=False)
            silence_slider = gr.Slider(label="靜音閾值 (秒)", minimum=0.3, maximum=1.2, value=0.8, step=0.1)
            threshold_slider = gr.Slider(label="語音閾值", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
            start_btn = gr.Button("🔴 Start")
            stop_btn = gr.Button("⏹️ Stop", interactive=False)
            clear_btn = gr.Button("🗑️ Clear")
        with gr.Column(scale=2):
            # Right column: live microphone input and the two output panes.
            mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="🎙️ Live Input")
            output = gr.Textbox(label="📝 Transcript", lines=15, autoscroll=True)
            debug_output = gr.Textbox(label="🔎 Debug Window", lines=15, autoscroll=True)
    start_btn.click(start_transcription, inputs=[vac_mode, silence_slider, threshold_slider],
                    outputs=[output, start_btn, stop_btn, debug_output])
    # NOTE(review): `output` appears twice below, matching stop_transcription's
    # five return values ("⏹️ Stopped" then stream_text) — the later binding
    # wins; confirm the duplication is intentional.
    stop_btn.click(stop_transcription, outputs=[output, start_btn, stop_btn, output, debug_output])
    clear_btn.click(clear_text, outputs=[output, debug_output])
    # Stream microphone chunks to process_stream every 0.5 s.
    mic.stream(process_stream, inputs=[mic], outputs=[output, debug_output], stream_every=0.5)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0",
                server_port=7860,
                share=False,
                ssr_mode=False)  # disable SSR