import gradio as gr
from transformers import AutoModelForCTC, AutoFeatureExtractor, AutoTokenizer
import torch
import numpy as np
import warnings
import librosa

warnings.filterwarnings("ignore")

MODEL_ID = "google/medasr"

model = None
feature_extractor = None
tokenizer = None


def normalize_audio(audio):
    """RMS normalization."""
    rms = np.sqrt(np.mean(audio ** 2))
    if rms > 0:
        audio = audio / rms
    audio = np.clip(audio, -1.0, 1.0)
    return audio


def remove_silence(audio, sample_rate, threshold=0.01):
    """Trim leading and trailing silence."""
    energy = np.abs(audio)
    above_threshold = energy > threshold
    if not np.any(above_threshold):
        return audio
    start = np.where(above_threshold)[0][0]
    end = np.where(above_threshold)[0][-1]
    buffer = int(0.1 * sample_rate)  # keep a 100 ms margin on each side
    start = max(0, start - buffer)
    end = min(len(audio), end + buffer)
    return audio[start:end]


def load_model_with_token(hf_token):
    global model, feature_extractor, tokenizer
    if not hf_token or not hf_token.strip():
        return gr.update(interactive=False, value="❌ Token cannot be empty!"), gr.update(interactive=False)
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print("🔄 Loading model components...")
        print(f"📱 Device: {device}")
        feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID, token=hf_token.strip())
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token.strip())
        model = AutoModelForCTC.from_pretrained(MODEL_ID, token=hf_token.strip()).to(device)
        model.eval()
        print(f"✅ Loaded: {type(feature_extractor)}, {type(tokenizer)}, {type(model)}")
        return gr.update(interactive=False, value="✅ Model Loaded Successfully!"), gr.update(interactive=True)
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return gr.update(interactive=True, value=f"❌ Error: {str(e)}"), gr.update(interactive=False)


def transcribe_audio(audio_input):
    global model, feature_extractor, tokenizer
    if audio_input is None:
        return "⚠️ Please upload or record audio."
    if model is None or feature_extractor is None:
        return "❌ Please load the model first!"
    try:
        # 1. Unpack the audio (Gradio's "numpy" type yields (sample_rate, waveform))
        sample_rate, waveform = audio_input

        # 2. Convert to mono
        if waveform.ndim == 2:
            waveform = waveform[:, 0]

        # 3. Convert to float32 and rescale
        if waveform.dtype == np.int16:
            waveform = waveform.astype(np.float32) / 32768.0
        elif waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)

        # 4. RMS normalization
        waveform = normalize_audio(waveform)

        # 5. Trim silence
        waveform = remove_silence(waveform, sample_rate)

        # 6. Check the duration
        duration = len(waveform) / sample_rate
        if duration < 0.1:
            return "⚠️ Audio is too short."
        if duration > 60:
            return "⚠️ Audio is too long."

        # 7. Resample to 16 kHz
        if sample_rate != 16000:
            waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000

        # 8. Feature extraction
        inputs = feature_extractor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # [Fix] Automatically find the key that holds the feature data
        # (it may be 'input_features' or 'input_values'). Skip
        # 'attention_mask' and locate the actual input tensor.
        input_tensor = None
        for key, val in inputs.items():
            if isinstance(val, torch.Tensor) and val.ndim > 1:
                input_tensor = val
                break
        if input_tensor is None:
            return "❌ Error: Could not extract audio features."

        # Read the stride safely; fall back to 4 if the extractor has none
        stride = 4
        if hasattr(feature_extractor, 'stride'):
            s = feature_extractor.stride
            stride = s[0] if isinstance(s, (list, tuple)) else s

        # Compute max_length dynamically from the input length
        max_length = input_tensor.shape[1] // stride + 50

        # 9. Beam search decoding
        # NOTE: this assumes the checkpoint supports .generate(); a plain
        # CTC head would instead be decoded greedily from the argmax of
        # its logits.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=8,  # beam search improves accuracy
            )

        # 10. Decode
        transcription = tokenizer.batch_decode(outputs.tolist(), skip_special_tokens=True)[0]

        # 11. Post-processing
        transcription = transcription.strip()
        import re
        transcription = re.sub(r'\s+', ' ', transcription)
        return transcription if transcription else "⚠️ No speech detected."
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Transcription error: {str(e)}"


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 MedASR - Medical Speech Recognition")
    gr.Markdown("Optimized for medical dictation with Beam Search decoding.")
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🔑 Authentication")
            hf_token = gr.Textbox(label="HuggingFace Token", type="password", placeholder="hf_...")
            load_model_btn = gr.Button("🔑 Load Model", variant="primary", size="lg")
            gr.Markdown("## 📝 Tips")
            gr.Markdown("""
            - Speak **clearly and slowly**
            - Use **medical terms**
            - Short audio (3-10s) is best
            - Quiet environment
            """)
        with gr.Column(scale=2):
            gr.Markdown("## 🎙️ Input & Result")
            audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy")
            with gr.Row():
                transcribe_btn = gr.Button("🔄 Transcribe", variant="secondary", size="lg", interactive=False)
                clear_btn = gr.Button("🗑️ Clear", variant="stop", size="lg")
            output_text = gr.Textbox(label="Result", lines=12, placeholder="...")
            audio_info = gr.Textbox(label="Info", lines=2, interactive=False)

    def transcribe_wrapper(audio_in):
        res = transcribe_audio(audio_in)
        info = "Status: Success" if "❌" not in res and "⚠️" not in res else "Status: Check result"
        return res, info

    load_model_btn.click(
        fn=load_model_with_token,
        inputs=[hf_token],
        outputs=[load_model_btn, transcribe_btn]
    )
    transcribe_btn.click(
        fn=transcribe_wrapper,
        inputs=[audio_input],
        outputs=[output_text, audio_info]
    )
    clear_btn.click(
        fn=lambda: ("", "Ready"),
        outputs=[output_text, audio_info]
    )

if __name__ == "__main__":
    demo.launch()
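# ----------------------------------------------------------------------------
# Usage sketch: the commands below assume this file is saved as app.py; the
# filename and package versions are illustrative, not pinned by the script.
#
#   pip install gradio transformers torch librosa numpy
#   python app.py
#
# The UI asks for a HuggingFace token before loading google/medasr, so the
# checkpoint presumably requires authenticated (gated) access.
# ----------------------------------------------------------------------------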