import asyncio
import base64
import json
import logging
import os
import re
import shutil
import tempfile
import time
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Point XDG-aware tools at a writable cache directory inside the container.
os.environ['XDG_CACHE_HOME'] = '/app/.cache'
# Make sure the cache directory exists.
os.makedirs('/app/.cache', exist_ok=True)

# Configure logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Whisper API", version="1.0.0")

# whisper.cpp command-line binary used for transcription.
WHISPER_BINARY = "/app/build/bin/whisper-cli"
# Silero VAD model handed to whisper-cli (-vm) when the request enables VAD.
VAD_MODEL_PATH = "/app/models/ggml-silero-v5.1.2.bin"
# Upper bound for a single whisper-cli run, in seconds.
WHISPER_TIMEOUT_SEC = 300
# Extensions passed to whisper-cli as-is; anything else is converted with ffmpeg first.
SUPPORTED_FORMATS = (".wav", ".flac", ".mp3", ".ogg")
# Model names are interpolated into file paths, so only allow a conservative charset.
_MODEL_NAME_RE = re.compile(r"[A-Za-z0-9._-]+")
# MPEG audio Layer III frame-sync prefixes (MPEG-1/MPEG-2, with and without CRC).
_MP3_FRAME_SYNCS = (b'\xff\xfb', b'\xff\xfa', b'\xff\xf3', b'\xff\xf2')


class AudioRequest(BaseModel):
    """Request body for ``POST /transcribe``."""

    audio: str  # base64-encoded audio data; an optional data-URL prefix is stripped
    model: str = "tiny"  # default to the tiny model for speed
    language: Optional[str] = "zh"  # default language is Chinese; None/"" -> "auto"
    temperature: Optional[float] = 0.0  # passed as -tp only when non-zero
    beam_size: Optional[int] = 1  # values other than 1 override the mode's -bs preset
    fast_mode: Optional[bool] = True  # fast mode: trade some accuracy for speed
    vad: Optional[bool] = False  # enable whisper-cli voice activity detection
    threads: Optional[int] = 4  # whisper-cli thread count (-t)
    atempo: Optional[float] = 1.0  # ffmpeg atempo speed factor; 1.0 = unchanged


def _remove_file_quietly(path: Optional[str]) -> None:
    """Delete *path* if it exists; log (never raise) other OS errors."""
    if not path:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"清理临时文件时出错: {e}")


def load_model(model_name: str) -> str:
    """Return the path of an on-disk ggml model file for *model_name*.

    Looks for ``/app/models/ggml-<name>.bin`` and
    ``/app/models/for-tests-ggml-<name>.bin``; if neither exists, falls back
    to the first bundled test model that is present.

    Raises:
        HTTPException(400): *model_name* is empty or contains characters
            outside ``[A-Za-z0-9._-]`` (it comes from the request body and
            is interpolated into a file path).
        HTTPException(500): no model file could be found at all.
    """
    if not model_name or not _MODEL_NAME_RE.fullmatch(model_name):
        raise HTTPException(status_code=400, detail=f"Invalid model name: {model_name!r}")

    # Candidate locations for the requested model.
    possible_paths = [
        f"/app/models/ggml-{model_name}.bin",
        f"/app/models/for-tests-ggml-{model_name}.bin",
    ]
    for path in possible_paths:
        if os.path.exists(path):
            logger.info(f"找到模型: {path}")
            return path

    # Requested model not found: fall back to a bundled test model.
    test_models = [
        "/app/models/for-tests-ggml-base.bin",
        "/app/models/ggml-base.en.bin",
        "/app/models/for-tests-ggml-tiny.bin",
    ]
    for test_model in test_models:
        if os.path.exists(test_model):
            logger.info(f"使用测试模型: {test_model}")
            return test_model

    # Not even a test model is available.
    logger.error(f"找不到模型 {model_name},请确保模型文件存在")
    raise HTTPException(status_code=500, detail=f"Model {model_name} not found")


async def convert_audio_to_wav(input_file: str, atempo: float = 1.0) -> str:
    """Convert *input_file* to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    When *atempo* differs from 1.0, ffmpeg's ``atempo`` filter is applied to
    change playback speed. On success the input file is deleted and the path
    of the converted file (``<input stem>_converted.wav``) is returned.

    Raises:
        HTTPException(400): *atempo* is not a positive number.
        HTTPException(500): ffmpeg failed, could not be started, or produced
            no output file.
    """
    # A missing value means "no speed change", not "atempo=None" on the CLI.
    if atempo is None:
        atempo = 1.0
    if atempo <= 0:
        raise HTTPException(status_code=400, detail=f"Invalid atempo value: {atempo}")

    output_file = os.path.splitext(input_file)[0] + '_converted.wav'

    # 16 kHz sample rate, mono, 16-bit little-endian PCM.
    cmd = [
        "ffmpeg",
        "-i", input_file,
        "-ar", "16000",
        "-ac", "1",
        "-c:a", "pcm_s16le",
    ]
    if atempo != 1.0:
        cmd += ["-filter:a", f"atempo={atempo}"]
    cmd += ["-y", output_file]

    try:
        logger.info(f"开始音频转换: {' '.join(cmd)}")
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()

        if proc.returncode != 0:
            # errors='replace': ffmpeg output may contain non-UTF-8 bytes
            # (e.g. file names) and must not mask the real failure.
            error_msg = stderr.decode(errors='replace') if stderr else "Unknown ffmpeg error"
            logger.error(f"音频转换失败: {error_msg}")
            # Don't leave a partially written output file behind.
            _remove_file_quietly(output_file)
            raise HTTPException(status_code=500, detail=f"Audio conversion failed: {error_msg}")

        if not os.path.exists(output_file):
            raise HTTPException(status_code=500, detail="Converted audio file not found")

        # The original upload is no longer needed once converted.
        _remove_file_quietly(input_file)

        # Approximate duration: 16 kHz * 2 bytes/sample * 1 channel (ignores the WAV header).
        file_size = os.path.getsize(output_file)
        duration_sec = file_size / (16000 * 2 * 1)
        logger.info(f"音频转换成功: {output_file}, 大小: {file_size} 字节, 时长: {duration_sec:.2f} 秒")
        return output_file
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"音频转换过程中出错: {e}")
        raise HTTPException(status_code=500, detail=f"Audio conversion error: {str(e)}")


def _sniff_audio_extension(audio_data: bytes) -> str:
    """Guess a file extension from the container's magic bytes (default ``.wav``)."""
    if len(audio_data) < 12:
        return ".wav"
    header = audio_data[:12]
    if header[:4] == b'RIFF' and header[8:12] == b'WAVE':
        logger.info("检测到WAV格式")
        return ".wav"
    if header[4:8] == b'ftyp':
        # ISO base media file (M4A/MP4/3GP) with any brand, not only "M4A ".
        logger.info("检测到M4A格式")
        return ".m4a"
    if header[:3] == b'ID3' or header[:2] in _MP3_FRAME_SYNCS:
        logger.info("检测到MP3格式")
        return ".mp3"
    if header[:4] == b'OggS':
        logger.info("检测到OGG格式")
        return ".ogg"
    if header[:4] == b'fLaC':
        logger.info("检测到FLAC格式")
        return ".flac"
    if header[:4] == b'\x1a\x45\xdf\xa3':
        # EBML magic: WebM/Matroska (e.g. browser MediaRecorder output).
        logger.info("检测到WebM格式")
        return ".webm"
    logger.warning(f"未知音频格式,文件头: {header.hex()}")
    return ".wav"


def decode_audio(audio_base64: str) -> str:
    """Decode base64 (optionally data-URL) audio into a temp file; return its path.

    The file extension is chosen by sniffing the leading bytes (WAV, M4A/MP4,
    MP3, OGG, FLAC, WebM); unrecognised data is saved as ``.wav``. The caller
    owns the returned file and is responsible for deleting it.

    Raises:
        HTTPException(400): invalid base64, empty payload, or any other
            unexpected decoding failure.
        HTTPException(500): the temporary file could not be created.
    """
    try:
        # Strip a data-URL prefix such as "data:audio/wav;base64," if present
        # (the base64 alphabet itself never contains a comma).
        if "," in audio_base64:
            mime_type, audio_base64 = audio_base64.split(",", 1)
            logger.info(f"检测到MIME类型: {mime_type}")

        try:
            audio_data = base64.b64decode(audio_base64)
            logger.info(f"成功解码音频数据,大小: {len(audio_data)} 字节")
        except ValueError as e:  # binascii.Error is a ValueError subclass
            logger.error(f"Base64解码失败: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid base64 data: {str(e)}")

        if not audio_data:
            raise HTTPException(status_code=400, detail="Empty audio data")

        file_extension = _sniff_audio_extension(audio_data)

        # Create the temp file with the detected extension so the format
        # check in transcribe_audio can route it correctly.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension, mode="wb") as temp_file:
            temp_file.write(audio_data)
            temp_path = temp_file.name

        # NamedTemporaryFile creates the file 0600; make it world-readable.
        os.chmod(temp_path, 0o644)

        if not os.path.exists(temp_path):
            raise HTTPException(status_code=500, detail="Failed to create temporary audio file")

        logger.info(f"音频文件已保存到: {temp_path}, 大小: {os.path.getsize(temp_path)} 字节, 格式: {file_extension}")

        if file_extension not in SUPPORTED_FORMATS:
            logger.warning(f"音频格式 {file_extension} 可能不被whisper-cli支持,支持的格式: {list(SUPPORTED_FORMATS)}")

        return temp_path
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"音频解码失败: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid audio data: {str(e)}")


def parse_whisper_output(output_file: str, stdout: bytes, exit_code: int) -> dict:
    """Build the response from whisper-cli's output.

    If ``<output_file>.json`` exists, its parsed content is returned with an
    added ``full_text`` key (the concatenated ``text`` of every
    ``transcription`` segment). Otherwise the raw process output is returned
    with ``status`` and ``exit_code``. Read/parse errors are reported in an
    ``error`` key rather than raised.
    """
    json_output_file = output_file + ".json"

    if not os.path.exists(json_output_file):
        # No JSON output: fall back to the captured command-line output.
        logger.warning(f"未找到JSON输出文件: {json_output_file}")
        return {
            "text": stdout.decode(errors='replace'),
            "status": "completed" if exit_code == 0 else "failed",
            "exit_code": exit_code,
        }

    try:
        with open(json_output_file, 'r', encoding='utf-8', errors='replace') as f:
            result = json.load(f)
        result["full_text"] = "".join(
            item.get("text", "") for item in result.get("transcription", [])
        )
        logger.info(f"成功读取JSON输出文件: {json_output_file}")
    except Exception as e:
        # Best effort: surface the problem in the response body.
        logger.error(f"读取JSON输出文件失败: {e}")
        result = {"error": f"Failed to read JSON output: {str(e)}"}
    return result


def cleanup_temp_files(audio_file, output_file, temp_dir):
    """Best-effort removal of a request's audio file, JSON output and temp dir.

    Any argument may be None (the request failed before that resource was
    created). Each step is attempted independently; errors are logged, never
    raised. Source files of a conversion are already deleted by
    ``convert_audio_to_wav``.
    """
    _remove_file_quietly(audio_file)
    if output_file:
        _remove_file_quietly(output_file + ".json")
    if temp_dir and os.path.exists(temp_dir):
        shutil.rmtree(temp_dir, ignore_errors=True)


def _build_whisper_cmd(request: AudioRequest, model_path: str, audio_file: str, output_file: str) -> List[str]:
    """Build the whisper-cli argv for *request* (executed without a shell)."""
    threads = request.threads or 4
    cmd = [
        WHISPER_BINARY,
        "-m", model_path,
        "-f", audio_file,
        "-l", request.language or "auto",
        "-oj",                # --output-json
        "-of", output_file,   # output path; parse_whisper_output reads <output_file>.json
        "-t", str(threads),   # thread count requested by the caller
    ]
    if request.fast_mode:
        # Fast mode: trade some accuracy for speed.
        cmd += [
            "-bs", "1",  # beam size 1
            "-bo", "1",  # best-of 1 (greedy)
            "-ac", "0",  # audio context 0 = full context (whisper-cli default)
            "-nf",       # --no-fallback: disable temperature fallback
            "-nt",       # --no-timestamps
        ]
    else:
        # Standard mode: balance speed and accuracy.
        cmd += [
            "-bs", "5",  # beam size 5
            "-bo", "5",  # best-of 5
        ]

    # Only add VAD flags when requested; never emit empty argv elements.
    if request.vad:
        cmd += ["--vad", "-vm", VAD_MODEL_PATH]

    # An explicit, non-default beam size overrides the mode's preset.
    if request.beam_size and request.beam_size != 1:
        cmd[cmd.index("-bs") + 1] = str(request.beam_size)

    if request.temperature:
        cmd += ["-tp", str(request.temperature)]  # short form of --temperature
    return cmd


async def _kill_process(proc) -> None:
    """Kill *proc* (if still running) and reap it."""
    try:
        proc.kill()
    except ProcessLookupError:
        pass  # already exited
    await proc.wait()


async def _run_whisper(cmd: List[str], output_file: str) -> dict:
    """Run whisper-cli with a timeout and return the parsed result.

    Raises:
        HTTPException(500): timeout or failure to run the process.
    """
    start_time = time.monotonic()
    logger.info(f"开始执行命令: {' '.join(cmd)}")
    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            # Merge stderr so diagnostics appear in the fallback "text" field.
            stderr=asyncio.subprocess.STDOUT,
        )
        logger.info("whisper子进程已创建,开始等待输出")
        # Bound the wait so a stuck process cannot hang the request forever.
        stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=WHISPER_TIMEOUT_SEC)
    except asyncio.TimeoutError:
        logger.error("命令执行超时")
        raise HTTPException(status_code=500, detail="Command execution timed out")
    except Exception as e:
        logger.error(f"处理过程中出错: {e}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
    finally:
        # Covers timeout, errors and task cancellation: never orphan the child.
        if proc is not None and proc.returncode is None:
            await _kill_process(proc)

    exit_code = proc.returncode
    processing_time = time.monotonic() - start_time
    logger.info(f"命令执行完成,退出码: {exit_code},处理时间: {processing_time:.2f}秒")

    result = parse_whisper_output(output_file, stdout, exit_code)
    result["processing_time"] = f"{processing_time:.2f}"
    result["cmd"] = " ".join(cmd)
    return result


@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest):
    """Transcribe base64 audio with whisper.cpp and return the result.

    The response is whisper-cli's JSON output plus ``full_text``, or the raw
    process output when no JSON was produced, always with ``processing_time``
    and ``cmd``. Status codes raised by helpers (e.g. 400 for bad input) are
    propagated unchanged; unexpected errors become 500. All temporary files
    are removed whether or not the request succeeds.
    """
    audio_file = None
    output_file = None
    temp_dir = None
    try:
        logger.info(f"收到转录请求: 模型={request.model}, 语言={request.language}")

        # Decode the audio into a temporary file.
        audio_file = decode_audio(request.audio)

        model_path = load_model(request.model)
        logger.info(f"使用模型: {model_path}")
        logger.info(f"使用whisper二进制: {WHISPER_BINARY}")

        # Convert to WAV when whisper-cli can't read the format directly, or
        # when a speed change is requested (only ffmpeg can apply atempo).
        atempo = request.atempo if request.atempo is not None else 1.0
        if not audio_file.endswith(SUPPORTED_FORMATS):
            logger.info(f"音频格式不直接支持,将转换为WAV: {audio_file}")
            audio_file = await convert_audio_to_wav(audio_file, atempo)
        elif atempo != 1.0:
            logger.info(f"应用变速 atempo={atempo},将转换为WAV: {audio_file}")
            audio_file = await convert_audio_to_wav(audio_file, atempo)

        # Private temp directory for whisper-cli's output file.
        temp_dir = tempfile.mkdtemp()
        output_file = os.path.join(temp_dir, "output")

        cmd = _build_whisper_cmd(request, model_path, audio_file, output_file)
        return await _run_whisper(cmd, output_file)
    except HTTPException:
        # Keep the status code and detail chosen where the error originated.
        raise
    except Exception as e:
        logger.error(f"转录失败: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
    finally:
        cleanup_temp_files(audio_file, output_file, temp_dir)


@app.get("/health")
async def health_check():
    """Report whether the whisper binary exists and which model files are visible."""
    try:
        binary_exists = os.path.exists(WHISPER_BINARY)

        # Collect *.bin files from the known model directories.
        model_dirs = ["/app/models", "/models"]
        model_files = []
        for dir_path in model_dirs:
            if not os.path.exists(dir_path):
                continue
            try:
                model_files.extend(
                    f"{dir_path}/{f}" for f in sorted(os.listdir(dir_path)) if f.endswith(".bin")
                )
            except OSError as e:
                logger.warning(f"无法列出模型目录 {dir_path}: {e}")

        return {
            "status": "healthy",
            "whisper_binary": WHISPER_BINARY,
            "binary_exists": binary_exists,
            "model_dirs": {dir_path: os.path.exists(dir_path) for dir_path in model_dirs},
            "available_models": model_files,
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
        }


@app.get("/")
async def root():
    """Describe the service and its endpoints."""
    return {
        "message": "Whisper API is running",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "transcribe": "/transcribe",
        },
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)