# Uploaded by thomaskywong0131 via huggingface_hub (revision 43602d3)
# app.py
import os, gc, warnings, logging
import torch, numpy as np, librosa, gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
from huggingface_hub import login
# -------------------------------
# HF Token Login (for private repos)
# -------------------------------
# Authenticate with the Hugging Face Hub when a token is provided in the
# environment (required for pulling private model repos).
if os.environ.get("HF_TOKEN") is not None:
    login(token=os.environ["HF_TOKEN"])
# -------------------------------
# Config & Device
# -------------------------------
# Silence noisy library warnings; set up a module logger (currently the UI's
# Debug window is fed by log_debug/debug_text, not this logger).
warnings.filterwarnings("ignore")
logger = logging.getLogger("whisper_streaming")
logger.setLevel(logging.DEBUG)
# Prefer GPU with fp16 weights; fall back to CPU with fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"Using device: {device}, dtype={torch_dtype}")
# -------------------------------
# Model Loading
# -------------------------------
# UI dropdown label -> Hugging Face model repo id.
MODEL_OPTIONS = {
    "Fine-tuned Cantonese": "thomaskywong0131/whisper-large-v3-cantonese",
    "OpenAI Large-v3": "openai/whisper-large-v3",
    "OpenAI Large-v3-Turbo": "openai/whisper-large-v3-turbo",
}
def load_model(model_choice="Fine-tuned Cantonese"):
    """Load the selected Whisper checkpoint and wrap it in an ASR pipeline.

    Parameters
    ----------
    model_choice : str
        A key of MODEL_OPTIONS selecting which checkpoint to load.

    Returns
    -------
    tuple
        (pipeline, WhisperProcessor) for the loaded model.
    """
    repo_id = MODEL_OPTIONS[model_choice]
    print(f"Loading model: {repo_id}")

    # Release as much GPU memory as possible before loading a new checkpoint.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        gc.collect()

    whisper_processor = WhisperProcessor.from_pretrained(repo_id)
    whisper_model = WhisperForConditionalGeneration.from_pretrained(
        repo_id,
        dtype=torch_dtype,
        device_map="auto" if device == "cuda" else None,
        use_safetensors=True,
    )

    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model=whisper_model,
        tokenizer=whisper_processor.tokenizer,
        feature_extractor=whisper_processor.feature_extractor,
        dtype=torch_dtype,
        generate_kwargs={"language": "yue"},  # force Cantonese decoding
    )

    print(f"✅ Successfully loaded: {model_choice}")
    return asr_pipe, whisper_processor


pipe, processor = load_model("Fine-tuned Cantonese")
# -------------------------------
# HypothesisBuffer
# -------------------------------
class HypothesisBuffer:
    """Accumulates transcript segments as (start, end, text) tuples.

    Timestamps may be None when the producer does not track timing.
    """

    def __init__(self):
        self.entries = []

    def insert(self, new, offset=0):
        """Append segments, shifting any non-None timestamps by *offset*."""
        shifted = [
            (None if beg is None else beg + offset,
             None if end is None else end + offset,
             text)
            for beg, end, text in new
        ]
        self.entries.extend(shifted)

    def reset(self):
        """Drop all stored segments."""
        self.entries = []

    def get_text(self):
        """Return every segment's text concatenated without separators."""
        return "".join(text for _, _, text in self.entries)

    def get_entries(self):
        return self.entries

    def complete(self):
        # Alias for get_entries, kept for API compatibility.
        return self.entries

    def flush(self):
        # Alias for get_entries, kept for API compatibility.
        return self.entries
# -------------------------------
# OnlineASRProcessor
# -------------------------------
class OnlineASRProcessor:
    """Buffers raw audio samples and transcribes them with a Whisper pipeline.

    Timestamps are not tracked: every result is (None, None, text).
    """

    def __init__(self, pipe, processor, sample_rate=16000):
        self.pipe = pipe
        self.processor = processor
        self.sample_rate = sample_rate
        self.audio_accum = np.array([], dtype=np.float32)
        self.transcript_buffer = HypothesisBuffer()

    def init(self):
        """Reset all accumulated audio and transcript state."""
        self.audio_accum = np.array([], dtype=np.float32)
        self.transcript_buffer.reset()

    def insert_audio_chunk(self, audio: np.ndarray):
        """Append a chunk of float32 samples to the internal buffer."""
        self.audio_accum = np.append(self.audio_accum, audio)

    def process_iter(self):
        """Transcribe the buffer once at least one second of audio is queued.

        Returns (None, None, text); text is "" when nothing was emitted yet.
        """
        if len(self.audio_accum) < self.sample_rate:
            # Not enough audio buffered yet — wait for more chunks.
            return None, None, ""
        return self._run_asr(10)

    def finish(self):
        """Flush and transcribe whatever audio remains, however short."""
        if len(self.audio_accum) == 0:
            return None, None, ""
        return self._run_asr(30)

    def _run_asr(self, chunk_length_s):
        # Shared transcription path for process_iter/finish. Errors are
        # surfaced as inline "[ASR error: ...]" text so streaming continues
        # instead of crashing the callback.
        try:
            txt = self.pipe(self.audio_accum, chunk_length_s=chunk_length_s)["text"].strip()
        except Exception as e:
            txt = f"[ASR error: {e}]"
        if not txt:
            return None, None, ""
        self.transcript_buffer.insert([(None, None, txt)])
        self.audio_accum = np.array([], dtype=np.float32)
        return None, None, txt
# -------------------------------
# VACOnlineASRProcessor (Silero VAD)
# -------------------------------
class VACOnlineASRProcessor:
    """Voice-activity-controlled ASR front end.

    Runs Silero VAD frame-by-frame over incoming audio; only frames scored
    as speech are accumulated, and after `silence_sec` of continuous
    silence the accumulated speech is transcribed and queued for the UI.
    """

    def __init__(self, pipe, processor, silence_sec=0.8, speech_threshold=0.5):
        self.online = OnlineASRProcessor(pipe, processor)
        # Loads the Silero VAD model via torch hub (cached after the first
        # download since force_reload=False).
        self.model, _ = torch.hub.load(
            repo_or_dir="snakers4/silero-vad",
            model="silero_vad",
            force_reload=False
        )
        self.sample_rate = 16000
        self.frame_size = 512  # samples per VAD frame (~32 ms at 16 kHz)
        self.silence_sec = silence_sec
        self.speech_threshold = speech_threshold
        self.reset()

    def reset(self):
        """Clear all buffered audio and any pending flushed transcripts."""
        self.online.init()
        self.buffer = np.array([], dtype=np.float32)       # unframed residue
        self.audio_accum = np.array([], dtype=np.float32)  # speech-only frames
        self.silence_samples = 0
        self.flush_queue = []

    def insert_audio_chunk(self, audio: np.ndarray):
        """Feed raw samples; run VAD per frame and flush on sustained silence."""
        if audio.dtype != np.float32:
            audio = audio.astype(np.float32)
        # Heuristic int16 -> [-1, 1] rescale.
        # NOTE(review): this divides in place, so it can mutate the caller's
        # array when the input was already float32 — confirm intended.
        if audio.max() > 1.0 or audio.min() < -1.0:
            audio /= 32768.0
        self.buffer = np.append(self.buffer, audio)
        while len(self.buffer) >= self.frame_size:
            frame = self.buffer[:self.frame_size]
            self.buffer = self.buffer[self.frame_size:]
            tensor = torch.from_numpy(frame).unsqueeze(0)
            with torch.no_grad():
                speech_prob = self.model(tensor, self.sample_rate).item()
            log_debug(f"[VAD] prob={speech_prob:.2f}, silence={self.silence_samples}, accum={len(self.audio_accum)}")
            if speech_prob > self.speech_threshold:
                # Speech frame: keep it and restart the silence counter.
                self.audio_accum = np.append(self.audio_accum, frame)
                self.silence_samples = 0
            else:
                self.silence_samples += self.frame_size
                # After silence_sec of uninterrupted silence, transcribe the
                # accumulated speech segment and queue the result.
                if self.silence_samples >= self.sample_rate * self.silence_sec:
                    if len(self.audio_accum) > 0:
                        self.online.insert_audio_chunk(self.audio_accum)
                        beg, end, txt = self.online.finish()
                        if txt:
                            self.flush_queue.append((beg, end, txt))
                            log_debug(f"[FLUSH] Added to queue: {txt}")
                    self.audio_accum = np.array([], dtype=np.float32)
                    self.silence_samples = 0

    def process_iter(self):
        """Pop and return the oldest queued flush, or (None, None, "")."""
        if self.flush_queue:
            return self.flush_queue.pop(0)
        return None, None, ""

    def finish(self):
        """Transcribe any audio still held by the online processor.

        NOTE(review): self.audio_accum (speech frames not yet handed to
        self.online) is not flushed here — confirm that is intended.
        """
        beg, end, txt = self.online.finish()
        if txt:
            return beg, end, txt
        return None, None, ""
# -------------------------------
# Gradio Callbacks
# -------------------------------
# Mutable UI state shared across the Gradio callbacks below.
stream_text = ""   # accumulated transcript shown in the Transcript box
debug_text = ""    # accumulated log shown in the Debug window
use_vac = False    # whether VAC (VAD-gated) mode is currently active
vac_online = None  # VACOnlineASRProcessor, created when Start is pressed
online = OnlineASRProcessor(pipe, processor)  # basic (non-VAC) processor
silence_sec_value = 0.8       # seconds of silence that triggers a flush
speech_threshold_value = 0.5  # Silero VAD speech-probability threshold
def log_debug(msg):
    """Append one line to the global debug log shown in the Debug window."""
    global debug_text
    debug_text = debug_text + msg + "\n"
def start_transcription(vac_mode, silence_sec, speech_threshold):
    """Initialise streaming state when the Start button is pressed.

    Returns (status text, Start-button update, Stop-button update, debug log).
    """
    global stream_text, debug_text, use_vac, vac_online, online
    global silence_sec_value, speech_threshold_value

    stream_text = ""
    debug_text = ""
    use_vac = vac_mode
    silence_sec_value = silence_sec
    speech_threshold_value = speech_threshold

    if not use_vac:
        online.init()
        log_debug("[START] VAC mode disabled (basic streaming)")
    else:
        # Build a fresh VAD-gated processor with the chosen thresholds.
        vac_online = VACOnlineASRProcessor(
            pipe, processor,
            silence_sec=silence_sec_value,
            speech_threshold=speech_threshold_value
        )
        vac_online.reset()
        log_debug("[START] VAC mode enabled")

    log_debug(f"[SETTINGS] silence_sec={silence_sec_value:.2f}, speech_threshold={speech_threshold_value:.2f}")
    return "🔴 Streaming started", gr.update(interactive=False), gr.update(interactive=True), debug_text
def stop_transcription():
    """Re-enable Start, disable Stop, and echo the current transcript/log."""
    start_update = gr.update(interactive=True)
    stop_update = gr.update(interactive=False)
    return "⏹️ Stopped", start_update, stop_update, stream_text, debug_text
def process_stream(audio):
    """Gradio streaming callback: feed one incoming audio chunk to the ASR.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | array-like | None
        Either (sample_rate, samples) from the microphone component,
        a raw sample array, or None when no audio arrived.

    Returns
    -------
    tuple[str, str]
        The accumulated transcript and the accumulated debug log.
    """
    global stream_text, debug_text, use_vac, vac_online, online
    if audio is None:
        return stream_text, debug_text

    if isinstance(audio, tuple):
        sr, arr = audio
        arr = np.asarray(arr, dtype=np.float32)
        # Bug fix: an empty chunk would make arr.max() raise ValueError.
        if arr.size == 0:
            return stream_text, debug_text
        # Bug fix: downmix stereo (samples, channels) to mono before
        # resampling; librosa.resample expects a 1-D signal here.
        if arr.ndim > 1:
            arr = arr.mean(axis=1)
        # Heuristic: int16-range samples are rescaled to [-1, 1].
        # (Non-in-place division so the caller's array is never mutated.)
        if arr.max() > 1.0 or arr.min() < -1.0:
            arr = arr / 32768.0
        if sr != 16000:
            arr = librosa.resample(arr, orig_sr=sr, target_sr=16000)
    else:
        arr = np.asarray(audio, dtype=np.float32)
        if arr.size == 0:
            return stream_text, debug_text

    if use_vac:
        vac_online.insert_audio_chunk(arr)
        beg, end, txt = vac_online.process_iter()
        log_debug(f"[VAC] Insert {len(arr)} samples | Output: {txt}")
    else:
        online.insert_audio_chunk(arr)
        beg, end, txt = online.process_iter()
        log_debug(f"[Online] Insert {len(arr)} samples | Output: {txt}")

    if txt:
        stream_text += txt + "\n"
        log_debug(f"[Flush] {beg}-{end} | '{txt}'")
    return stream_text, debug_text
def clear_text():
    """Reset both the transcript and the debug log, returning the blanks."""
    global stream_text, debug_text
    stream_text, debug_text = "", ""
    return stream_text, debug_text
# -------------------------------
# Gradio UI
# -------------------------------
# Build the Gradio interface: controls on the left, live mic input and
# transcript/debug displays on the right.
with gr.Blocks(title="Cantonese Streaming (VAC)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎤 Cantonese Streaming Transcription with VAC + Debug Logs")
    gr.Markdown("✅ 支援 VAC,並可在下方調整靜音閾值與語音閾值")
    with gr.Row():
        with gr.Column(scale=1):
            # VAC toggle, VAD tuning sliders, and transport buttons.
            vac_mode = gr.Checkbox(label="啟用 VAC 模式", value=False)
            silence_slider = gr.Slider(label="靜音閾值 (秒)", minimum=0.3, maximum=1.2, value=0.8, step=0.1)
            threshold_slider = gr.Slider(label="語音閾值", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
            start_btn = gr.Button("🔴 Start")
            stop_btn = gr.Button("⏹️ Stop", interactive=False)
            clear_btn = gr.Button("🗑️ Clear")
        with gr.Column(scale=2):
            # Live microphone input streamed as (sample_rate, numpy array).
            mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="🎙️ Live Input")
            output = gr.Textbox(label="📝 Transcript", lines=15, autoscroll=True)
            debug_output = gr.Textbox(label="🔎 Debug Window", lines=15, autoscroll=True)
    start_btn.click(start_transcription, inputs=[vac_mode, silence_slider, threshold_slider],
                    outputs=[output, start_btn, stop_btn, debug_output])
    # NOTE(review): "output" appears twice in stop_btn's outputs, so the
    # stream_text value overwrites the "⏹️ Stopped" status — confirm intended.
    stop_btn.click(stop_transcription, outputs=[output, start_btn, stop_btn, output, debug_output])
    clear_btn.click(clear_text, outputs=[output, debug_output])
    # The mic streams ~0.5 s chunks into process_stream.
    mic.stream(process_stream, inputs=[mic], outputs=[output, debug_output], stream_every=0.5)
if __name__ == "__main__":
    # Bind on all interfaces (container-friendly); SSR is disabled for
    # compatibility with the streaming audio component.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        ssr_mode=False,
    )