1een commited on
Commit
705e333
·
1 Parent(s): 6a589e7
Files changed (3) hide show
  1. app.py +0 -329
  2. fixed_app.py +28 -22
  3. startup.sh +1 -1
app.py DELETED
@@ -1,329 +0,0 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- import base64
4
- import io
5
- import tempfile
6
- import os
7
- import requests
8
- from typing import Optional, List, Dict, Any
9
- import logging
10
- from urllib.parse import urlparse
11
- import time
12
- from fastapi.responses import StreamingResponse
13
- import subprocess
14
- import asyncio
15
-
16
- # 设置缓存目录
17
- os.environ['XDG_CACHE_HOME'] = '/app/.cache'
18
- # os.environ['TORCH_HOME'] = '/app/.cache/torch'
19
-
20
- # 确保缓存目录存在
21
- os.makedirs('/app/.cache', exist_ok=True)
22
- # os.makedirs('/app/.cache/torch', exist_ok=True)
23
-
24
- # 配置日志
25
- logging.basicConfig(level=logging.INFO)
26
- logger = logging.getLogger(__name__)
27
-
28
- app = FastAPI(title="Whisper API", version="1.0.0")
29
-
30
- # 启动事件:预加载模型
31
- @app.on_event("startup")
32
- async def startup_event():
33
- """应用启动时的初始化操作"""
34
- logger.info("Starting Whisper API...")
35
- try:
36
- # 不在启动时预加载模型,改为按需加载以避免启动阻塞
37
- logger.info("Whisper API ready - models will be loaded on demand")
38
- except Exception as e:
39
- logger.error(f"Startup warning: {e}")
40
- # 继续启动,不因为模型加载失败而阻塞
41
- logger.info("Whisper API startup complete")
42
-
43
- # 全局变量存储模型
44
- models = {}
45
-
46
- # 预加载模型列表
47
- PRELOAD_MODELS = ["tiny", "base", "small"]
48
-
49
- class AudioRequest(BaseModel):
50
- audio: str # base64 编码的音频数据
51
- model: str = "base" # 改为small模型,准确度更高
52
- language: Optional[str] = "zh" # 默认中文
53
- task: Optional[str] = "transcribe"
54
- temperature: Optional[float] = 0.0 # 温度越高,生成文本的随机性越大,温度越低,生成文本的随机性越小
55
- word_timestamps: Optional[bool] = False # 默认关闭词级时间戳
56
- # output_format: str = "text" # 支持 json 或 text
57
- compression_ratio_threshold: Optional[float] = 2.4 # 压缩比阈值,用于过滤掉低质量的片段
58
- logprob_threshold: Optional[float] = -1.0 # 对数概率阈值,用于过滤掉低质量的片段
59
- no_speech_threshold: Optional[float] = 0.6 # 无语音阈值,用于过滤掉无语音的片段
60
- device: Optional[str] = None
61
- fp16: Optional[bool] = False # CPU 默认关闭 fp16
62
- beam_size: Optional[int] = 1 # 默认束搜索为1
63
- condition_on_previous_text: Optional[bool] = False # 默认关闭上下文
64
-
65
- def get_device():
66
- return "cpu"
67
-
68
- def load_model(model_name: str):
69
- """确保模型文件存在,返回模型路径"""
70
- # 检查多个可能的模型路径
71
- possible_paths = [
72
- f"/app/models/ggml-{model_name}.bin",
73
- f"/app/models/{model_name}.bin",
74
- f"/app/models/for-tests-ggml-{model_name}.bin",
75
- f"/models/ggml-{model_name}.bin",
76
- f"/models/{model_name}.bin"
77
- ]
78
-
79
- # 检查是否有任何一个路径存在
80
- for path in possible_paths:
81
- if os.path.exists(path):
82
- logger.info(f"找到模型: {path}")
83
- return path
84
-
85
- # 如果没有找到,使用测试模型
86
- test_model = "/app/models/for-tests-ggml-base.bin"
87
- if os.path.exists(test_model):
88
- logger.info(f"使用测试模型: {test_model}")
89
- return test_model
90
-
91
- # 如果连测试模型都没有,报错
92
- logger.error(f"找不到模型 {model_name},请确保模型文件存在")
93
- raise HTTPException(status_code=500, detail=f"Model {model_name} not found")
94
-
95
- def preload_models():
96
- """启动时预加载模型"""
97
- # device = get_device()
98
- # logger.info(f"预加载模型到设备: {device}")
99
-
100
- total_start_time = time.time()
101
- for model_name in PRELOAD_MODELS:
102
- try:
103
- model_start_time = time.time()
104
- logger.info(f"开始预加载模型: {model_name}")
105
- load_model(model_name)
106
- model_load_time = time.time() - model_start_time
107
- logger.info(f"模型 {model_name} 预加载成功,耗时: {model_load_time:.2f}秒")
108
- except Exception as e:
109
- logger.error(f"模型 {model_name} 预加载失败: {e}")
110
- # 继续加载其他模型,不中断程序启动
111
-
112
- total_time = time.time() - total_start_time
113
- logger.info(f"所有模型预加载完成,总耗时: {total_time:.2f}秒")
114
-
115
- class TranscriptionProgressLogger:
116
- """转录进度日志记录器"""
117
- def __init__(self, request_id: str = None):
118
- self.request_id = request_id or str(int(time.time()))
119
- self.start_time = time.time()
120
- self.segment_count = 0
121
- self.last_segment_time = self.start_time
122
- self.segments_info = []
123
-
124
- def log_start(self, audio_duration: float = None):
125
- """记录转录开始"""
126
- if audio_duration:
127
- logger.info(f"[{self.request_id}] 开始转录 - 音频时长: {audio_duration:.2f}秒")
128
- else:
129
- logger.info(f"[{self.request_id}] 开始转录音频")
130
-
131
- def log_segment_progress(self, segment_id: int, start_time: float, end_time: float, text: str):
132
- """记录片段转录进度"""
133
- self.segment_count += 1
134
- current_time = time.time()
135
-
136
- # 计算从上一个片段到现在的时间
137
- segment_processing_time = current_time - self.last_segment_time
138
- self.last_segment_time = current_time
139
-
140
- # 计算总耗时
141
- total_elapsed = current_time - self.start_time
142
-
143
- # 存储片段信息
144
- self.segments_info.append({
145
- "id": segment_id,
146
- "start": start_time,
147
- "end": end_time,
148
- "duration": end_time - start_time,
149
- "processing_time": segment_processing_time
150
- })
151
-
152
- # 计算实时速度比(音频时长与处理时间的比值)
153
- segment_duration = end_time - start_time
154
- speed_ratio = segment_duration / segment_processing_time if segment_processing_time > 0 else 0
155
-
156
- # 记录日志
157
- logger.info(
158
- f"[{self.request_id}] 片段 {segment_id}/{self.segment_count} "
159
- f"({start_time:.1f}s-{end_time:.1f}s, 时长:{segment_duration:.1f}s): "
160
- f"'{text[:30]}{'...' if len(text) > 30 else ''}' "
161
- f"(处理耗时: {segment_processing_time:.2f}s, 速度比: {speed_ratio:.1f}x, 总耗时: {total_elapsed:.2f}s)"
162
- )
163
-
164
- def log_completion(self, total_segments: int, total_text_length: int):
165
- """记录转录完成"""
166
- elapsed = time.time() - self.start_time
167
-
168
- # 计算总音频时长
169
- total_audio_duration = sum(segment["duration"] for segment in self.segments_info) if self.segments_info else 0
170
-
171
- # 计算平均速度比
172
- avg_speed_ratio = total_audio_duration / elapsed if elapsed > 0 else 0
173
-
174
- # 计算每秒处理的文本量
175
- text_per_second = total_text_length / elapsed if elapsed > 0 else 0
176
-
177
- logger.info(
178
- f"[{self.request_id}] 转录完成 - "
179
- f"总片段: {total_segments}, "
180
- f"文本长度: {total_text_length}字符, "
181
- f"音频时长: {total_audio_duration:.2f}秒, "
182
- f"处理耗时: {elapsed:.2f}秒, "
183
- f"平均速度比: {avg_speed_ratio:.1f}x, "
184
- f"处理速度: {text_per_second:.1f}字/秒"
185
- )
186
-
187
- def decode_audio(audio_base64: str) -> tuple:
188
- """解码base64音频数据并保存为临时文件,返回文件路径和音频大小"""
189
- try:
190
- # 移除data URL前缀(如果存在)
191
- if "," in audio_base64:
192
- audio_base64 = audio_base64.split(",")[1]
193
-
194
- # 解码base64
195
- start_time = time.time()
196
- audio_data = base64.b64decode(audio_base64)
197
- decode_time = time.time() - start_time
198
-
199
- # 获取音频大小(字节)
200
- audio_size = len(audio_data)
201
- logger.info(f"音频解码完成: {audio_size/1024:.2f} KB, 耗时: {decode_time:.2f}s")
202
-
203
- # 创建临时文件
204
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
205
- temp_file.write(audio_data)
206
- return temp_file.name
207
- except Exception as e:
208
- logger.error(f"音频解码失败: {str(e)}")
209
- raise HTTPException(status_code=400, detail=f"Invalid audio data: {str(e)}")
210
-
211
- @app.post("/transcribe")
212
- async def transcribe_audio(request: AudioRequest):
213
- """音频转录API,异步调用 whisper.cpp 并流式返回分段结果"""
214
- try:
215
- # 解码音频并保存为临时文件
216
- audio_file = decode_audio(request.audio)
217
- model_path = load_model(request.model) # 确保模型存在
218
-
219
- # 检查whisper.cpp二进制路径
220
- whisper_binary = "/app/build/bin/main"
221
- if not os.path.exists(whisper_binary):
222
- # 尝试其他可能的路径
223
- possible_binaries = [
224
- "/app/main",
225
- "/usr/local/bin/whisper",
226
- "/usr/local/bin/whisper.cpp"
227
- ]
228
- for binary in possible_binaries:
229
- if os.path.exists(binary):
230
- whisper_binary = binary
231
- break
232
-
233
- logger.info(f"使用whisper二进制: {whisper_binary}")
234
- logger.info(f"使用模型: {model_path}")
235
-
236
- cmd = [
237
- whisper_binary, # whisper.cpp 主程序路径
238
- "-m", model_path,
239
- "-f", audio_file,
240
- "-l", request.language or "zh",
241
- "--output-json",
242
- "--print-progress",
243
- "--split-on-word",
244
- "-t", str(os.cpu_count() or 1),
245
- ]
246
- except Exception as e:
247
- logger.error(f"准备转录失败: {e}")
248
- raise HTTPException(status_code=500, detail=f"Failed to prepare transcription: {str(e)}")
249
- # 添加可选参数
250
- if request.beam_size:
251
- cmd += ["--beam-size", str(request.beam_size)]
252
- if request.temperature:
253
- cmd += ["--temperature", str(request.temperature)]
254
- # 其���参数可按需添加
255
-
256
- async def event_stream():
257
- proc = await asyncio.create_subprocess_exec(
258
- *cmd,
259
- stdout=asyncio.subprocess.PIPE,
260
- stderr=asyncio.subprocess.STDOUT,
261
- )
262
- try:
263
- async for line in proc.stdout:
264
- line = line.decode().strip()
265
- if line.startswith("{"):
266
- yield f"data: {line}\n\n"
267
- await proc.wait()
268
- finally:
269
- # 清理临时文件
270
- if os.path.exists(audio_file):
271
- os.unlink(audio_file)
272
-
273
- return StreamingResponse(event_stream(), media_type="text/event-stream")
274
-
275
- @app.get("/health")
276
- async def health_check():
277
- """健康检查"""
278
- try:
279
- # 检查whisper.cpp二进制是否存在
280
- whisper_binary = "/app/build/bin/main"
281
- binary_exists = os.path.exists(whisper_binary)
282
-
283
- # 检查模型目录
284
- model_dirs = ["/app/models", "/models"]
285
- model_files = []
286
-
287
- for dir_path in model_dirs:
288
- if os.path.exists(dir_path):
289
- try:
290
- model_files.extend([f"{dir_path}/{f}" for f in os.listdir(dir_path) if f.endswith(".bin")])
291
- except:
292
- pass
293
-
294
- return {
295
- "status": "healthy",
296
- "whisper_binary": whisper_binary,
297
- "binary_exists": binary_exists,
298
- "model_dirs": {dir_path: os.path.exists(dir_path) for dir_path in model_dirs},
299
- "available_models": model_files
300
- }
301
- except Exception as e:
302
- return {
303
- "status": "error",
304
- "error": str(e)
305
- }
306
-
307
- @app.get("/models")
308
- async def list_models():
309
- """列出可用模型"""
310
- return {
311
- "models": ["tiny", "base", "small", "medium", "large", "turbo"]
312
- }
313
-
314
- @app.get("/")
315
- async def root():
316
- """根路径"""
317
- return {
318
- "message": "Whisper API is running",
319
- "version": "1.0.0",
320
- "endpoints": {
321
- "health": "/health",
322
- "models": "/models",
323
- "transcribe": "/transcribe"
324
- }
325
- }
326
-
327
- if __name__ == "__main__":
328
- import uvicorn
329
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fixed_app.py CHANGED
@@ -8,7 +8,6 @@ from typing import Optional
8
  import logging
9
  import time
10
  import asyncio
11
- import shutil
12
 
13
  # 设置缓存目录
14
  os.environ['XDG_CACHE_HOME'] = '/app/.cache'
@@ -202,6 +201,28 @@ def parse_whisper_output(output_file: str, stdout: bytes, exit_code: int) -> dic
202
  }
203
  return result
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  @app.post("/transcribe")
206
  async def transcribe_audio(request: AudioRequest):
207
  """音频转录API,异步调用 whisper.cpp 并返回转录结果"""
@@ -254,7 +275,7 @@ async def transcribe_audio(request: AudioRequest):
254
  whisper_binary,
255
  "-m", model_path,
256
  "-f", audio_file,
257
- "-l", request.language or "zh",
258
  "-oj", # --output-json: 输出JSON格式
259
  "-of", output_file, # 指定输出文件
260
  "-t", str(request.threads), # 使用所有CPU核心
@@ -272,8 +293,6 @@ async def transcribe_audio(request: AudioRequest):
272
  if request.temperature:
273
  cmd += ["-tp", str(request.temperature)] # --temperature 的简写
274
 
275
- # logger.info(f"完整命令: {' '.join(cmd)}")
276
-
277
  try:
278
  # 执行命令
279
  start_time = time.time()
@@ -300,9 +319,9 @@ async def transcribe_audio(request: AudioRequest):
300
  logger.warning("输出包含非UTF-8字符,已替换")
301
 
302
  # 记录输出日志
303
- for line in output_text.splitlines():
304
- if line.strip():
305
- logger.info(f"whisper输出: {line.strip()}")
306
 
307
  # 检查退出码
308
  exit_code = proc.returncode
@@ -312,6 +331,7 @@ async def transcribe_audio(request: AudioRequest):
312
  # 读取JSON输出文件
313
  result = parse_whisper_output(output_file, stdout, exit_code)
314
  result["processing_time"] = f"{processing_time:.2f}"
 
315
 
316
  return result
317
 
@@ -329,21 +349,7 @@ async def transcribe_audio(request: AudioRequest):
329
  raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
330
  finally:
331
  # 清理临时文件
332
- if os.path.exists(audio_file):
333
- os.unlink(audio_file)
334
- # 如果有转换后的文件,也要清理
335
- if audio_file.endswith('_converted.wav'):
336
- original_file = audio_file.replace('_converted.wav', '.m4a')
337
- if os.path.exists(original_file):
338
- os.unlink(original_file)
339
- # 清理输出文件
340
- json_output_file = output_file + ".json"
341
- if os.path.exists(json_output_file):
342
- os.unlink(json_output_file)
343
- # 清理临时目录
344
- if os.path.exists(temp_dir):
345
- import shutil
346
- shutil.rmtree(temp_dir, ignore_errors=True)
347
  except Exception as e:
348
  logger.error(f"转录失败: {e}")
349
  raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
 
8
  import logging
9
  import time
10
  import asyncio
 
11
 
12
  # 设置缓存目录
13
  os.environ['XDG_CACHE_HOME'] = '/app/.cache'
 
201
  }
202
  return result
203
 
204
+ def cleanup_temp_files(audio_file, output_file, temp_dir):
205
+ """清理音频、输出文件和临时目录"""
206
+ try:
207
+ # 删除音频文件
208
+ if audio_file and os.path.exists(audio_file):
209
+ os.unlink(audio_file)
210
+ # 删除转换后的文件(如 _converted.wav)
211
+ if audio_file and audio_file.endswith('_converted.wav'):
212
+ original_file = audio_file.replace('_converted.wav', '.m4a')
213
+ if os.path.exists(original_file):
214
+ os.unlink(original_file)
215
+ # 删除输出JSON文件
216
+ json_output_file = output_file + ".json"
217
+ if os.path.exists(json_output_file):
218
+ os.unlink(json_output_file)
219
+ # 删除临时目录
220
+ if temp_dir and os.path.exists(temp_dir):
221
+ import shutil
222
+ shutil.rmtree(temp_dir, ignore_errors=True)
223
+ except Exception as e:
224
+ logger.warning(f"清理临时文件时出错: {e}")
225
+
226
  @app.post("/transcribe")
227
  async def transcribe_audio(request: AudioRequest):
228
  """音频转录API,异步调用 whisper.cpp 并返回转录结果"""
 
275
  whisper_binary,
276
  "-m", model_path,
277
  "-f", audio_file,
278
+ "-l", request.language or "auto",
279
  "-oj", # --output-json: 输出JSON格式
280
  "-of", output_file, # 指定输出文件
281
  "-t", str(request.threads), # 使用所有CPU核心
 
293
  if request.temperature:
294
  cmd += ["-tp", str(request.temperature)] # --temperature 的简写
295
 
 
 
296
  try:
297
  # 执行命令
298
  start_time = time.time()
 
319
  logger.warning("输出包含非UTF-8字符,已替换")
320
 
321
  # 记录输出日志
322
+ # for line in output_text.splitlines():
323
+ # if line.strip():
324
+ # logger.info(f"whisper输出: {line.strip()}")
325
 
326
  # 检查退出码
327
  exit_code = proc.returncode
 
331
  # 读取JSON输出文件
332
  result = parse_whisper_output(output_file, stdout, exit_code)
333
  result["processing_time"] = f"{processing_time:.2f}"
334
+ result["cmd"] = " ".join(cmd)
335
 
336
  return result
337
 
 
349
  raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
350
  finally:
351
  # 清理临时文件
352
+ cleanup_temp_files(audio_file, output_file, temp_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  except Exception as e:
354
  logger.error(f"转录失败: {e}")
355
  raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
startup.sh CHANGED
@@ -1,7 +1,7 @@
1
  #!/bin/bash
2
 
3
  # 显示环境信息
4
- echo "=== Whisper API Startup 0.7==="
5
  echo "Python version: $(python3 --version)"
6
  echo "Current directory: $(pwd)"
7
  # echo "Files in /app:"
 
1
  #!/bin/bash
2
 
3
  # 显示环境信息
4
+ echo "=== Whisper API Startup 0.8==="
5
  echo "Python version: $(python3 --version)"
6
  echo "Current directory: $(pwd)"
7
  # echo "Files in /app:"