|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
import base64 |
|
|
import tempfile |
|
|
import os |
|
|
import json |
|
|
from typing import Optional |
|
|
import logging |
|
|
import time |
|
|
import asyncio |
|
|
|
|
|
|
|
|
# Redirect library cache writes to a container-writable location; required
# when the process user has no writable $HOME (e.g. HF Spaces / Docker).
os.environ['XDG_CACHE_HOME'] = '/app/.cache'

os.makedirs('/app/.cache', exist_ok=True)

# Module-wide logging; INFO level so every transcription step is traceable.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Whisper API", version="1.0.0")
|
|
|
|
|
class AudioRequest(BaseModel):
    """Request body for the /transcribe endpoint."""

    # Base64-encoded audio; a data-URI prefix ("data:...;base64,") is accepted.
    audio: str
    # whisper.cpp model name (e.g. "tiny", "base"); resolved by load_model().
    model: str = "tiny"
    # Language code passed to whisper-cli "-l"; None falls back to "auto".
    language: Optional[str] = "zh"
    # Sampling temperature; forwarded as "-tp" only when non-zero.
    temperature: Optional[float] = 0.0
    # Beam size; values != 1 override the fast-mode default of 1.
    beam_size: Optional[int] = 1
    # True -> greedy low-latency flags (-bs 1 -bo 1 -nf -nt); False -> beam 5.
    fast_mode: Optional[bool] = True
    # Enable Silero voice-activity detection (only applied in fast mode).
    vad: Optional[bool] = False
    # CPU thread count for whisper-cli ("-t").
    threads: Optional[int] = 4
    # ffmpeg atempo speed factor applied during WAV conversion.
    atempo: Optional[float] = 1.0
|
|
|
|
|
def load_model(model_name: str):
    """Resolve the on-disk path of a whisper model file.

    Looks for the requested model under /app/models first, then falls back
    to any of the bundled test models. Raises HTTP 500 when nothing exists.
    """
    requested = (
        f"/app/models/ggml-{model_name}.bin",
        f"/app/models/for-tests-ggml-{model_name}.bin",
    )
    for candidate in requested:
        if os.path.exists(candidate):
            logger.info(f"找到模型: {candidate}")
            return candidate

    # Fallback: any bundled test model keeps the service usable.
    fallbacks = (
        "/app/models/for-tests-ggml-base.bin",
        "/app/models/ggml-base.en.bin",
        "/app/models/for-tests-ggml-tiny.bin",
    )
    for fallback in fallbacks:
        if os.path.exists(fallback):
            logger.info(f"使用测试模型: {fallback}")
            return fallback

    logger.error(f"找不到模型 {model_name},请确保模型文件存在")
    raise HTTPException(status_code=500, detail=f"Model {model_name} not found")
|
|
|
|
|
async def convert_audio_to_wav(input_file: str, atempo: float = 1.0) -> str:
    """Transcode an audio file to 16 kHz mono PCM WAV using ffmpeg.

    Applies an optional `atempo` speed filter. On success the source file is
    deleted and the path of the converted file is returned; failures are
    surfaced as HTTP 500 errors.
    """
    try:
        # "<name>.<ext>" -> "<name>_converted.wav"
        output_file = input_file.rsplit('.', 1)[0] + '_converted.wav'

        cmd = ["ffmpeg", "-i", input_file, "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le"]
        if atempo != 1.0:
            cmd.extend(["-filter:a", f"atempo={atempo}"])
        cmd.extend(["-y", output_file])

        logger.info(f"开始音频转换: {' '.join(cmd)}")

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()

        if proc.returncode != 0:
            error_msg = stderr.decode() if stderr else "Unknown ffmpeg error"
            logger.error(f"音频转换失败: {error_msg}")
            raise HTTPException(status_code=500, detail=f"Audio conversion failed: {error_msg}")

        if not os.path.exists(output_file):
            raise HTTPException(status_code=500, detail="Converted audio file not found")

        # The original upload is no longer needed once conversion succeeded.
        if os.path.exists(input_file):
            os.unlink(input_file)

        file_size = os.path.getsize(output_file)
        # 16000 samples/s * 2 bytes/sample * 1 channel (WAV header ignored).
        duration_sec = file_size / (16000 * 2 * 1)
        logger.info(f"音频转换成功: {output_file}, 大小: {file_size} 字节, 时长: {duration_sec:.2f} 秒")
        return output_file

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"音频转换过程中出错: {e}")
        raise HTTPException(status_code=500, detail=f"Audio conversion error: {str(e)}")
|
|
|
|
|
def decode_audio(audio_base64: str) -> str:
    """Decode base64 audio, sniff its container format, and persist it to a
    temporary file.

    Accepts both raw base64 and data-URI payloads ("data:...;base64,<data>").
    Returns the temporary file's path; the caller owns (and must delete) it.

    Raises:
        HTTPException 400: the payload is not valid base64.
        HTTPException 500: the temporary file could not be created.
    """
    try:
        # Strip an optional data-URI prefix ("data:audio/wav;base64,...").
        if "," in audio_base64:
            mime_type, _, audio_base64 = audio_base64.partition(",")
            logger.info(f"检测到MIME类型: {mime_type}")

        try:
            audio_data = base64.b64decode(audio_base64)
            logger.info(f"成功解码音频数据,大小: {len(audio_data)} 字节")
        except Exception as e:
            logger.error(f"Base64解码失败: {e}")
            raise HTTPException(status_code=400, detail=f"Invalid base64 data: {str(e)}")

        file_extension = _sniff_audio_extension(audio_data)

        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension, mode="wb") as temp_file:
            temp_file.write(audio_data)
            temp_path = temp_file.name

        # Make the file readable for the whisper-cli subprocess.
        os.chmod(temp_path, 0o644)

        if not os.path.exists(temp_path):
            raise HTTPException(status_code=500, detail="Failed to create temporary audio file")

        logger.info(f"音频文件已保存到: {temp_path}, 大小: {os.path.getsize(temp_path)} 字节, 格式: {file_extension}")

        # Formats outside this list are converted to WAV before transcription.
        supported_formats = [".wav", ".flac", ".mp3", ".ogg"]
        if file_extension not in supported_formats:
            logger.warning(f"音频格式 {file_extension} 可能不被whisper-cli支持,支持的格式: {supported_formats}")

        return temp_path
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"音频解码失败: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid audio data: {str(e)}")


def _sniff_audio_extension(audio_data: bytes) -> str:
    """Guess a file extension from the payload's magic bytes.

    Defaults to ".wav" when the payload is too short or unrecognized.
    """
    file_extension = ".wav"
    if len(audio_data) >= 12:
        header = audio_data[:12]
        if header[:4] == b'RIFF' and header[8:12] == b'WAVE':
            file_extension = ".wav"
            logger.info("检测到WAV格式")
        elif header[4:8] == b'ftyp':
            # BUGFIX: any ISO-BMFF container (brands "M4A ", "isom", "mp42",
            # ...) carries "ftyp" at offset 4; the old check only matched a
            # literal "M4A" brand, so most m4a/mp4 files skipped conversion.
            file_extension = ".m4a"
            logger.info("检测到M4A格式")
        elif header[:3] == b'ID3' or (header[:1] == b'\xff' and (header[1] & 0xE0) == 0xE0):
            # BUGFIX: MPEG frame sync is 11 set bits (0xFFEx..0xFFFx); the
            # old check matched only the 0xFFFB variant and missed
            # 0xFFF3/0xFFF2/0xFFFA headerless MP3s.
            file_extension = ".mp3"
            logger.info("检测到MP3格式")
        elif header[:4] == b'OggS':
            file_extension = ".ogg"
            logger.info("检测到OGG格式")
        elif header[:4] == b'fLaC':
            file_extension = ".flac"
            logger.info("检测到FLAC格式")
        else:
            logger.warning(f"未知音频格式,文件头: {header.hex()}")
    return file_extension
|
|
|
|
|
def parse_whisper_output(output_file: str, stdout: bytes, exit_code: int) -> dict:
    """Collect the transcription result produced by whisper-cli.

    Prefers the JSON sidecar file ("<output_file>.json"); falls back to the
    raw stdout text when no JSON file was written.
    """
    json_output_file = output_file + ".json"

    # Fallback path: no JSON sidecar, report the captured stdout instead.
    if not os.path.exists(json_output_file):
        logger.warning(f"未找到JSON输出文件: {json_output_file}")
        return {
            "text": stdout.decode(errors='replace'),
            "status": "completed" if exit_code == 0 else "failed",
            "exit_code": exit_code,
        }

    try:
        with open(json_output_file, 'r', encoding='utf-8', errors='replace') as fh:
            parsed = json.loads(fh.read())
        # Concatenate all segment texts into one convenience field.
        parsed["full_text"] = "".join(seg["text"] for seg in parsed.get("transcription", []))
        logger.info(f"成功读取JSON输出文件: {json_output_file}")
        return parsed
    except Exception as e:
        logger.error(f"读取JSON输出文件失败: {e}")
        return {"error": f"Failed to read JSON output: {str(e)}"}
|
|
|
|
|
def cleanup_temp_files(audio_file, output_file, temp_dir):
    """Best-effort removal of transcription artifacts.

    Deletes the (possibly converted) audio file, the pre-conversion original,
    the JSON sidecar produced by whisper-cli, and the scratch directory.
    All arguments may be None; failures are logged, never raised.
    """
    try:
        if audio_file and os.path.exists(audio_file):
            os.unlink(audio_file)

        # convert_audio_to_wav writes "<name>_converted.wav"; remove the
        # original m4a upload if conversion left it behind.
        if audio_file and audio_file.endswith('_converted.wav'):
            original_file = audio_file.replace('_converted.wav', '.m4a')
            if os.path.exists(original_file):
                os.unlink(original_file)

        # BUGFIX: guard against output_file=None (previously raised TypeError
        # on None + ".json", which the except then obscured with a warning).
        if output_file:
            json_output_file = output_file + ".json"
            if os.path.exists(json_output_file):
                os.unlink(json_output_file)

        if temp_dir and os.path.exists(temp_dir):
            import shutil
            shutil.rmtree(temp_dir, ignore_errors=True)
    except Exception as e:
        logger.warning(f"清理临时文件时出错: {e}")
|
|
|
|
|
@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest):
    """Transcription endpoint: run whisper.cpp asynchronously on the posted
    base64 audio and return the parsed result.

    Raises:
        HTTPException 400: invalid base64 / audio payload.
        HTTPException 500: missing model, conversion failure, whisper timeout,
            or any unexpected processing error.
    """
    try:
        logger.info(f"收到转录请求: 模型={request.model}, 语言={request.language}")

        # Persist the posted audio to a temp file (raises 400 on bad base64).
        audio_file = decode_audio(request.audio)

        model_path = load_model(request.model)
        logger.info(f"使用模型: {model_path}")

        whisper_binary = "/app/build/bin/whisper-cli"
        logger.info(f"使用whisper二进制: {whisper_binary}")

        # whisper-cli reads these containers directly; anything else
        # (e.g. m4a) is converted to 16 kHz mono WAV first.
        supported_formats = ('.wav', '.flac', '.mp3', '.ogg')
        if not audio_file.endswith(supported_formats):
            logger.info(f"音频格式不直接支持,将转换为WAV: {audio_file}")
            audio_file = await convert_audio_to_wav(audio_file, request.atempo)

        temp_dir = tempfile.mkdtemp()
        output_file = os.path.join(temp_dir, "output")

        # Common argv; mode-specific flags are appended below.
        cmd = [
            whisper_binary,
            "-m", model_path,
            "-f", audio_file,
            "-l", request.language or "auto",
            "-oj",                     # emit JSON sidecar next to output_file
            "-of", output_file,
            "-t", str(request.threads),
        ]
        if request.fast_mode:
            # Greedy decoding, no fallbacks/timestamps: optimized for latency.
            cmd += ["-bs", "1", "-bo", "1", "-ac", "0", "-nf", "-nt"]
            if request.vad:
                # BUGFIX: previously "-vm" and empty-string argv entries were
                # passed to the subprocess even when VAD was disabled.
                cmd += ["--vad", "-vm", "/app/models/ggml-silero-v5.1.2.bin"]
        else:
            cmd += ["-bs", "5", "-bo", "5"]

        # Caller-requested beam size overrides the fast-mode default of 1.
        if request.beam_size and request.beam_size != 1:
            if "-bs" in cmd:
                bs_index = cmd.index("-bs")
                if bs_index + 1 < len(cmd) and cmd[bs_index + 1] == "1":
                    cmd[bs_index + 1] = str(request.beam_size)
        if request.temperature:
            cmd += ["-tp", str(request.temperature)]

        # BUGFIX: defined before the try so the except blocks can test it even
        # when create_subprocess_exec itself raised (was a NameError before).
        proc = None
        try:
            start_time = time.time()
            logger.info(f"开始执行命令: {' '.join(cmd)}")

            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
            )
            logger.info("whisper子进程已创建,开始等待输出")

            stdout, _ = await asyncio.wait_for(
                proc.communicate(),
                timeout=300
            )

            # Decode purely to surface a warning about non-UTF-8 output.
            try:
                stdout.decode('utf-8')
            except UnicodeDecodeError:
                logger.warning("输出包含非UTF-8字符,已替换")

            exit_code = proc.returncode
            processing_time = time.time() - start_time
            logger.info(f"命令执行完成,退出码: {exit_code},处理时间: {processing_time:.2f}秒")

            result = parse_whisper_output(output_file, stdout, exit_code)
            result["processing_time"] = f"{processing_time:.2f}"
            result["cmd"] = " ".join(cmd)

            return result

        except asyncio.TimeoutError:
            logger.error("命令执行超时")
            if proc:
                proc.kill()
                await proc.wait()
            raise HTTPException(status_code=500, detail="Command execution timed out")
        except Exception as e:
            logger.error(f"处理过程中出错: {e}")
            if proc:
                proc.kill()
                await proc.wait()
            raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
        finally:
            cleanup_temp_files(audio_file, output_file, temp_dir)
    except HTTPException:
        # BUGFIX: propagate HTTPExceptions (e.g. 400 for bad input) with their
        # original status instead of rewrapping everything as 500.
        raise
    except Exception as e:
        logger.error(f"转录失败: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check: report whisper binary presence and available models."""
    try:
        whisper_binary = "/app/build/bin/whisper-cli"
        binary_exists = os.path.exists(whisper_binary)

        # Scan the known model directories for ggml .bin model files.
        model_dirs = ["/app/models", "/models"]
        model_files = []
        for dir_path in model_dirs:
            if os.path.exists(dir_path):
                try:
                    model_files.extend(
                        f"{dir_path}/{f}" for f in os.listdir(dir_path) if f.endswith(".bin")
                    )
                except OSError:
                    # BUGFIX: was a bare `except:` — now only swallows
                    # filesystem errors (permissions, directory race).
                    pass

        return {
            "status": "healthy",
            "whisper_binary": whisper_binary,
            "binary_exists": binary_exists,
            "model_dirs": {dir_path: os.path.exists(dir_path) for dir_path in model_dirs},
            "available_models": model_files
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e)
        }
|
|
|
|
|
@app.get("/")
async def root():
    """Service landing route: basic metadata and an endpoint directory."""
    endpoints = {
        "health": "/health",
        "transcribe": "/transcribe"
    }
    return {
        "message": "Whisper API is running",
        "version": "1.0.0",
        "endpoints": endpoints
    }
|
|
|
|
|
if __name__ == "__main__":
    # Local/dev entry point: serve on all interfaces, port 7860 (the default
    # port exposed by Hugging Face Spaces containers).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)