# whisper / fixed_app.py
# 1een's picture
# 93
# 8bbc7d1
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import base64
import tempfile
import os
import json
from typing import Optional
import logging
import time
import asyncio
# Set the cache directory (XDG_CACHE_HOME) to a path inside /app
os.environ['XDG_CACHE_HOME'] = '/app/.cache'
# Make sure the cache directory exists
os.makedirs('/app/.cache', exist_ok=True)
# Configure logging: INFO level on the root logger, module-level logger below
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application object that the route decorators below attach to
app = FastAPI(title="Whisper API", version="1.0.0")
# Request body accepted by the POST /transcribe endpoint.
class AudioRequest(BaseModel):
    audio: str  # base64-encoded audio data (an optional data-URL prefix is stripped by decode_audio)
    model: str = "tiny"  # model name; defaults to the tiny model for speed
    language: Optional[str] = "zh"  # passed to whisper-cli as -l; defaults to Chinese, "auto" when empty
    temperature: Optional[float] = 0.0  # passed as -tp when non-zero
    beam_size: Optional[int] = 1  # replaces the "-bs 1" value when set to something other than 1
    fast_mode: Optional[bool] = True  # fast mode: selects the speed-oriented whisper-cli flag set
    vad: Optional[bool] = False  # adds --vad and the VAD model path (fast mode command only)
    threads: Optional[int] = 4  # passed to whisper-cli as -t
    atempo: Optional[float] = 1.0  # tempo factor forwarded to convert_audio_to_wav (ffmpeg atempo filter)
def load_model(model_name: str):
    """Locate a model file on disk and return its path.

    The two file names derived from ``model_name`` are tried first; if
    neither exists, a fixed list of fallback test models is tried in order.

    Args:
        model_name: Short model name used to build the candidate file names.

    Returns:
        The path of the first model file that exists.

    Raises:
        HTTPException: 500 when neither a requested nor a fallback model exists.
    """
    requested_candidates = (
        f"/app/models/ggml-{model_name}.bin",
        f"/app/models/for-tests-ggml-{model_name}.bin",
    )
    fallback_candidates = (
        "/app/models/for-tests-ggml-base.bin",
        "/app/models/ggml-base.en.bin",
        "/app/models/for-tests-ggml-tiny.bin",
    )

    # First choice: a file that matches the requested model name
    found = next((p for p in requested_candidates if os.path.exists(p)), None)
    if found is not None:
        logger.info(f"找到模型: {found}")
        return found

    # Second choice: any of the bundled test models
    fallback = next((p for p in fallback_candidates if os.path.exists(p)), None)
    if fallback is not None:
        logger.info(f"使用测试模型: {fallback}")
        return fallback

    # Nothing usable on disk
    logger.error(f"找不到模型 {model_name},请确保模型文件存在")
    raise HTTPException(status_code=500, detail=f"Model {model_name} not found")
async def convert_audio_to_wav(input_file: str, atempo: float = 1.0) -> str:
    """Convert an audio file to 16 kHz mono 16-bit PCM WAV with ffmpeg.

    On success the original ``input_file`` is deleted and the path of the
    converted file (``<input without extension>_converted.wav``) is returned.
    On failure any partially written output file is removed; the input file
    is left in place for the caller to clean up.

    Args:
        input_file: Path of the audio file to convert.
        atempo: Tempo factor; when different from 1.0 it is applied through
            ffmpeg's ``atempo`` audio filter.

    Returns:
        Path of the converted WAV file.

    Raises:
        HTTPException: 500 when ffmpeg fails, produces no output file, or any
            other error occurs during conversion.
    """
    output_file = None
    converted = False
    try:
        # splitext only strips an extension from the final path component,
        # so a dot in a directory name cannot truncate the path.
        output_file = os.path.splitext(input_file)[0] + '_converted.wav'
        # ffmpeg command: 16 kHz sample rate, mono, 16-bit PCM encoder
        cmd = [
            "ffmpeg",
            "-i", input_file,
            "-ar", "16000",
            "-ac", "1",
            "-c:a", "pcm_s16le",
        ]
        if atempo != 1.0:
            cmd += ["-filter:a", f"atempo={atempo}"]
        cmd += [
            "-y",  # overwrite the output file without asking
            output_file,
        ]
        logger.info(f"开始音频转换: {' '.join(cmd)}")
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            # ffmpeg reads interactive commands from stdin by default; give it
            # /dev/null so it never touches the server's stdin.
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            # errors='replace' so undecodable bytes cannot mask the ffmpeg error
            error_msg = stderr.decode(errors='replace') if stderr else "Unknown ffmpeg error"
            logger.error(f"音频转换失败: {error_msg}")
            raise HTTPException(status_code=500, detail=f"Audio conversion failed: {error_msg}")
        # Verify that ffmpeg actually produced the output file
        if not os.path.exists(output_file):
            raise HTTPException(status_code=500, detail="Converted audio file not found")
        # The original file is no longer needed once conversion succeeded
        if os.path.exists(input_file):
            os.unlink(input_file)
        # Approximate duration: 16 kHz * 2 bytes per sample * 1 channel
        # (the small WAV header is included in file_size; used for logging only)
        file_size = os.path.getsize(output_file)
        duration_sec = file_size / (16000 * 2 * 1)
        logger.info(f"音频转换成功: {output_file}, 大小: {file_size} 字节, 时长: {duration_sec:.2f} 秒")
        converted = True
        return output_file
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"音频转换过程中出错: {e}")
        raise HTTPException(status_code=500, detail=f"Audio conversion error: {str(e)}")
    finally:
        # Do not leave a partially written output file behind on failure
        if not converted and output_file and os.path.exists(output_file):
            try:
                os.unlink(output_file)
            except OSError as cleanup_error:
                logger.warning(f"清理转换输出文件失败: {cleanup_error}")
def _detect_audio_extension(audio_data: bytes) -> str:
    """Guess a file extension from the magic bytes at the start of the data.

    Returns ".audio" for unrecognised data; that extension is deliberately
    not one of the formats handed to whisper-cli directly, so the caller's
    extension check routes the file through the ffmpeg conversion.
    """
    header = audio_data[:12]
    if header[:4] == b'RIFF' and header[8:12] == b'WAVE':
        logger.info("检测到WAV格式")
        return ".wav"
    # ISO base media files (MP4/M4A) carry 'ftyp' at offset 4 whatever the brand
    if header[4:8] == b'ftyp':
        logger.info("检测到MP4/M4A格式")
        return ".m4a"
    # MP3: ID3 tag, or an MPEG Layer III frame sync (11 sync bits set and
    # layer bits == 01). The mask rejects AAC ADTS, whose layer bits are 00.
    if header[:3] == b'ID3' or (
        len(header) >= 2 and header[0] == 0xFF and (header[1] & 0xE6) == 0xE2
    ):
        logger.info("检测到MP3格式")
        return ".mp3"
    if header[:4] == b'OggS':
        logger.info("检测到OGG格式")
        return ".ogg"
    if header[:4] == b'fLaC':
        logger.info("检测到FLAC格式")
        return ".flac"
    # EBML magic shared by WebM and Matroska
    if header[:4] == b'\x1a\x45\xdf\xa3':
        logger.info("检测到WebM格式")
        return ".webm"
    logger.warning(f"未知音频格式,文件头: {header.hex()}")
    return ".audio"


def decode_audio(audio_base64: str) -> str:
    """Decode base64 audio data and save it to a temporary file.

    A data-URL style prefix (everything up to the first comma) is stripped
    before decoding. The file extension is chosen from the decoded data's
    magic bytes.

    Args:
        audio_base64: Base64 text, optionally prefixed with ``<mime>,``.

    Returns:
        Path of the temporary file holding the decoded audio. The caller is
        responsible for deleting it.

    Raises:
        HTTPException: 400 for invalid base64, empty audio or any other
            decoding error; 500 if the temporary file could not be created.
    """
    temp_path = None
    saved = False
    try:
        # Strip the data-URL prefix if present
        if "," in audio_base64:
            mime_type, _, audio_base64 = audio_base64.partition(",")
            logger.info(f"检测到MIME类型: {mime_type}")
        # Decode base64
        try:
            audio_data = base64.b64decode(audio_base64)
            logger.info(f"成功解码音频数据,大小: {len(audio_data)} 字节")
        except (ValueError, TypeError) as e:  # binascii.Error is a ValueError
            logger.error(f"Base64解码失败: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid base64 data: {str(e)}")
        # An empty payload can never be transcribed; reject it up front
        if not audio_data:
            raise HTTPException(status_code=400, detail="Empty audio data")
        file_extension = _detect_audio_extension(audio_data)
        # Create the temporary file with the detected extension
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension, mode="wb") as temp_file:
            temp_path = temp_file.name
            temp_file.write(audio_data)
        # Make sure the file is readable
        os.chmod(temp_path, 0o644)
        if not os.path.exists(temp_path):
            raise HTTPException(status_code=500, detail="Failed to create temporary audio file")
        logger.info(f"音频文件已保存到: {temp_path}, 大小: {os.path.getsize(temp_path)} 字节, 格式: {file_extension}")
        # Warn when the format is not one whisper-cli is given directly
        supported_formats = [".wav", ".flac", ".mp3", ".ogg"]
        if file_extension not in supported_formats:
            logger.warning(f"音频格式 {file_extension} 可能不被whisper-cli支持,支持的格式: {supported_formats}")
        saved = True
        return temp_path
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"音频解码失败: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid audio data: {str(e)}")
    finally:
        # Do not leak the temporary file when a later step failed
        if not saved and temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"清理临时音频文件失败: {cleanup_error}")
def parse_whisper_output(output_file: str, stdout: bytes, exit_code: int) -> dict:
    """Build the transcription result from whisper's output.

    If ``<output_file>.json`` exists it is parsed and a ``full_text`` key
    (the concatenation of every ``transcription`` item's ``text``) is added.
    Otherwise the captured stdout is returned together with a status derived
    from the exit code.

    Args:
        output_file: Output path prefix passed to whisper (without ".json").
        stdout: Raw bytes captured from the whisper process.
        exit_code: Exit code of the whisper process.

    Returns:
        The parsed JSON dict, an ``{"error": ...}`` dict if reading it
        failed, or a ``text``/``status``/``exit_code`` dict when no JSON
        file was produced.
    """
    json_path = output_file + ".json"

    # No JSON file: fall back to the command-line output
    if not os.path.exists(json_path):
        logger.warning(f"未找到JSON输出文件: {json_path}")
        return {
            "text": stdout.decode(errors='replace'),
            "status": "completed" if exit_code == 0 else "failed",
            "exit_code": exit_code,
        }

    try:
        with open(json_path, 'r', encoding='utf-8', errors='replace') as handle:
            parsed = json.load(handle)
        segments = parsed.get("transcription", [])
        parsed["full_text"] = "".join(segment["text"] for segment in segments)
        logger.info(f"成功读取JSON输出文件: {json_path}")
    except Exception as exc:
        logger.error(f"读取JSON输出文件失败: {exc}")
        return {"error": f"Failed to read JSON output: {exc}"}
    return parsed
def cleanup_temp_files(audio_file, output_file, temp_dir):
    """Best-effort removal of the audio file, JSON output and temp directory.

    Every step is attempted independently, so a failure in one step (or a
    missing/None argument) does not prevent the remaining ones from running.
    Errors are logged as warnings and never raised.

    Args:
        audio_file: Path of the audio file given to whisper, or None.
        output_file: Output path prefix given to whisper (its ".json" file
            is removed), or None.
        temp_dir: Temporary output directory to remove recursively, or None.
    """
    import shutil

    def _remove_file(path):
        # Delete one file if it exists; log instead of raising on failure
        try:
            if path and os.path.exists(path):
                os.unlink(path)
        except Exception as e:
            logger.warning(f"清理临时文件时出错: {e}")

    # The audio file itself
    _remove_file(audio_file)
    # For a converted file, also remove a leftover .m4a original. Only the
    # trailing suffix is swapped; str.replace would rewrite every occurrence.
    converted_suffix = '_converted.wav'
    if isinstance(audio_file, str) and audio_file.endswith(converted_suffix):
        _remove_file(audio_file[:-len(converted_suffix)] + '.m4a')
    # The JSON output file
    if output_file:
        _remove_file(output_file + ".json")
    # The temporary directory and anything left inside it
    try:
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
    except Exception as e:
        logger.warning(f"清理临时文件时出错: {e}")
def _build_whisper_command(request, whisper_binary, model_path, audio_file, output_file):
    """Assemble the whisper-cli argument list for one transcription request."""
    # A client may send null for threads; fall back to the model default of 4
    threads = 4 if request.threads is None else request.threads
    cmd = [
        whisper_binary,
        "-m", model_path,
        "-f", audio_file,
        "-l", request.language or "auto",
        "-oj",               # --output-json: write a JSON result file
        "-of", output_file,  # output path prefix (whisper appends ".json")
        "-t", str(threads),  # number of threads
    ]
    # Fast mode trades some accuracy for speed (greedy decoding);
    # standard mode uses beam size / best-of 5.
    beam_size, best_of = (1, 1) if request.fast_mode else (5, 5)
    # A beam_size other than the field default of 1 overrides the mode's
    # default in BOTH modes (it used to be silently ignored in standard mode).
    if request.beam_size and request.beam_size != 1:
        beam_size = request.beam_size
    cmd += ["-bs", str(beam_size), "-bo", str(best_of)]
    if request.fast_mode:
        cmd += [
            # NOTE(review): whisper-cli documents "-ac 0" as "all" (the
            # default), so this is probably a no-op rather than a speed-up.
            "-ac", "0",
            "-nf",  # --no-fallback: disable temperature fallback
            "-nt",  # --no-timestamps
        ]
        # Only add the VAD flags when requested; passing empty-string
        # arguments makes whisper-cli treat "" as an extra input file.
        # NOTE(review): VAD is still only honoured in fast mode — confirm.
        if request.vad:
            cmd += ["--vad", "-vm", "/app/models/ggml-silero-v5.1.2.bin"]
    if request.temperature:
        cmd += ["-tp", str(request.temperature)]  # short form of --temperature
    return cmd


@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest):
    """Transcribe base64 audio by running whisper-cli asynchronously.

    Steps: decode the audio to a temp file, resolve the model path, convert
    the audio with ffmpeg when its format is not passed to whisper-cli
    directly (or when a tempo change is requested), run whisper-cli with a
    5-minute timeout, and return the parsed result with ``processing_time``
    and ``cmd`` added.

    Raises:
        HTTPException: errors raised by the helpers are passed through with
            their own status code; a timeout or any other failure is
            reported as 500.
    """
    audio_file = None
    output_file = None
    temp_dir = None
    proc = None
    try:
        logger.info(f"收到转录请求: 模型={request.model}, 语言={request.language}")
        # Decode the audio and save it to a temporary file
        audio_file = decode_audio(request.audio)
        # Resolve the model path
        model_path = load_model(request.model)
        logger.info(f"使用模型: {model_path}")
        whisper_binary = "/app/build/bin/whisper-cli"
        logger.info(f"使用whisper二进制: {whisper_binary}")
        # Convert to WAV when the format is not supported directly, or when a
        # tempo change was requested (atempo is applied by the ffmpeg step, so
        # it would otherwise be ignored for already-supported formats).
        atempo = 1.0 if request.atempo is None else request.atempo
        supported_formats = ('.wav', '.flac', '.mp3', '.ogg')
        if not audio_file.endswith(supported_formats):
            logger.info(f"音频格式不直接支持,将转换为WAV: {audio_file}")
            audio_file = await convert_audio_to_wav(audio_file, atempo)
        elif atempo != 1.0:
            logger.info(f"应用atempo={atempo},将转换为WAV: {audio_file}")
            audio_file = await convert_audio_to_wav(audio_file, atempo)
        # Temporary directory for whisper's output
        temp_dir = tempfile.mkdtemp()
        output_file = os.path.join(temp_dir, "output")
        cmd = _build_whisper_command(
            request, whisper_binary, model_path, audio_file, output_file
        )
        try:
            start_time = time.monotonic()
            logger.info(f"开始执行命令: {' '.join(cmd)}")
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,  # merge stderr into stdout
            )
            logger.info("whisper子进程已创建,开始等待输出")
            # Bound the wait so a hung process cannot block forever
            stdout, _ = await asyncio.wait_for(
                proc.communicate(),
                timeout=300,  # 5 minutes
            )
            exit_code = proc.returncode
            processing_time = time.monotonic() - start_time
            logger.info(f"命令执行完成,退出码: {exit_code},处理时间: {processing_time:.2f}秒")
            # Read the JSON output file (falls back to stdout when missing)
            result = parse_whisper_output(output_file, stdout, exit_code)
            result["processing_time"] = f"{processing_time:.2f}"
            result["cmd"] = " ".join(cmd)
            return result
        except asyncio.TimeoutError:
            logger.error("命令执行超时")
            raise HTTPException(status_code=500, detail="Command execution timed out")
        except Exception as e:
            logger.error(f"处理过程中出错: {e}")
            raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
    except HTTPException:
        # Keep the status code chosen by the helper (e.g. 400 for bad input)
        # instead of re-wrapping everything as a 500.
        raise
    except Exception as e:
        logger.error(f"转录失败: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
    finally:
        # Stop whisper-cli if it is still running (timeout, error, cancellation)
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # the process exited between the check and the kill
            await proc.wait()
        # Clean up temporary files on every exit path
        if output_file is not None:
            cleanup_temp_files(audio_file, output_file, temp_dir)
        elif audio_file and os.path.exists(audio_file):
            # Failed before the output directory existed: only the decoded
            # audio file needs removing.
            try:
                os.unlink(audio_file)
            except OSError as cleanup_error:
                logger.warning(f"清理临时文件时出错: {cleanup_error}")
@app.get("/health")
async def health_check():
    """Health check: report whether the whisper binary and model files exist.

    Returns:
        A dict with ``status`` "healthy", the binary path and whether it
        exists, which model directories exist, and every ``.bin`` file found
        in them. If an unexpected error occurs, a dict with ``status``
        "error" and the error text is returned instead.
    """
    try:
        # Check whether the whisper.cpp binary exists
        whisper_binary = "/app/build/bin/whisper-cli"
        binary_exists = os.path.exists(whisper_binary)
        # Collect the model files from the known model directories
        model_dirs = ["/app/models", "/models"]
        model_files = []
        for dir_path in model_dirs:
            if not os.path.exists(dir_path):
                continue
            try:
                entries = os.listdir(dir_path)
            except OSError as e:
                # Unreadable directory or not a directory: skip it, but say so
                # rather than hiding every possible exception.
                logger.warning(f"无法列出模型目录 {dir_path}: {e}")
                continue
            model_files.extend(
                f"{dir_path}/{name}" for name in entries if name.endswith(".bin")
            )
        return {
            "status": "healthy",
            "whisper_binary": whisper_binary,
            "binary_exists": binary_exists,
            "model_dirs": {dir_path: os.path.exists(dir_path) for dir_path in model_dirs},
            "available_models": model_files,
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
        }
@app.get("/")
async def root():
    """Root path: return a short service banner listing the endpoints."""
    endpoint_map = {
        "health": "/health",
        "transcribe": "/transcribe",
    }
    banner = {
        "message": "Whisper API is running",
        "version": "1.0.0",
        "endpoints": endpoint_map,
    }
    return banner
# When executed as a script, serve the app with uvicorn on all interfaces, port 7860
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)