simler commited on
Commit
69346ea
·
verified ·
1 Parent(s): d35da60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -26
app.py CHANGED
@@ -1,61 +1,59 @@
1
  import builtins
2
  import os
3
  import shutil
4
- import logging
5
 
6
- # 1. 劫持 input
7
  builtins.input = lambda prompt="": "y"
8
 
9
  # 2. 设置环境
10
  os.environ["GENIE_DATA_DIR"] = "/app/GenieData"
11
 
12
- # 3. 手动下载数据
13
  from huggingface_hub import snapshot_download
14
  if not os.path.exists("/app/GenieData/G2P"):
15
- print("📦 Downloading GenieData...")
16
  snapshot_download(repo_id="High-Logic/Genie", allow_patterns=["GenieData/*"], local_dir="/app", local_dir_use_symlinks=False)
17
 
18
- # 4. 导入引擎
19
  from fastapi import UploadFile, File, Form
20
  import genie_tts
21
  from genie_tts.Server import app as genie_app
22
- from genie_tts.ModelManager import model_manager
23
 
24
- # --- 增强版上传端点 ---
25
  @genie_app.post("/upload_and_set")
26
  async def upload_and_set(
27
- character_name: str = Form("Default"), # 改为大写 Default
28
  prompt_text: str = Form(...),
29
  language: str = Form("zh"),
30
  file: UploadFile = File(...)
31
  ):
32
- # 强制将名字改为首字母大写,与引擎保持一致
33
  char_name = character_name.capitalize()
34
- save_path = f"/app/uploaded_custom_ref.wav"
 
35
 
36
  with open(save_path, "wb") as buffer:
37
  shutil.copyfileobj(file.file, buffer)
38
 
 
39
  try:
40
- # 调用底层 model_manager 确保设置成功
41
- model_manager.set_reference_audio(char_name, save_path, prompt_text, language)
42
- print(f"✅ Success: Character '{char_name}' updated with text: {prompt_text}")
43
- return {"message": f"Success! Character '{char_name}' style updated."}
44
  except Exception as e:
45
- print(f"❌ Error setting ref audio: {e}")
46
  return {"error": str(e)}, 500
47
 
48
- # --- 调试端点:查看当前加载了哪些角色 ---
49
- @genie_app.get("/debug_chars")
50
- async def debug_chars():
51
- return {
52
- "loaded_characters": list(model_manager.loaded_characters.keys()),
53
- "detail": {k: v.ref_audio_path for k, v in model_manager.loaded_characters.items()}
54
- }
55
 
56
  if __name__ == "__main__":
57
- # 加载角色,直接用 Default
58
- model_manager.load_character("Default", "/app", "zh")
59
-
60
- # 启动服务器
 
 
 
61
  genie_tts.start_server(host="0.0.0.0", port=7860)
 
1
  import builtins
2
  import os
3
  import shutil
 
4
 
5
+ # 1. 劫持 input 防止崩溃
6
  builtins.input = lambda prompt="": "y"
7
 
8
  # 2. 设置环境
9
  os.environ["GENIE_DATA_DIR"] = "/app/GenieData"
10
 
11
+ # 3. 手动下载权重
12
  from huggingface_hub import snapshot_download
13
  if not os.path.exists("/app/GenieData/G2P"):
14
+ print("📦 Downloading GenieData Dependencies...")
15
  snapshot_download(repo_id="High-Logic/Genie", allow_patterns=["GenieData/*"], local_dir="/app", local_dir_use_symlinks=False)
16
 
17
+ # 4. 导入库
18
  from fastapi import UploadFile, File, Form
19
  import genie_tts
20
  from genie_tts.Server import app as genie_app
 
21
 
22
+ # --- 核心:上传并设置参考音频 ---
23
  @genie_app.post("/upload_and_set")
24
  async def upload_and_set(
25
+ character_name: str = Form("Default"),
26
  prompt_text: str = Form(...),
27
  language: str = Form("zh"),
28
  file: UploadFile = File(...)
29
  ):
 
30
  char_name = character_name.capitalize()
31
+ # 存放在固定位置
32
+ save_path = f"/app/uploaded_ref.wav"
33
 
34
  with open(save_path, "wb") as buffer:
35
  shutil.copyfileobj(file.file, buffer)
36
 
37
+ print(f"🔄 Setting reference audio for {char_name}: {prompt_text}")
38
  try:
39
+ # 🟢 调用官方顶级 API,这是最稳妥的
40
+ genie_tts.set_reference_audio(char_name, save_path, prompt_text, language)
41
+ return {"message": f"Successfully updated style for {char_name}"}
 
42
  except Exception as e:
43
+ print(f"❌ Error: {e}")
44
  return {"error": str(e)}, 500
45
 
46
+ # --- 极简调试:只看当前角色状态 ---
47
+ @genie_app.get("/status")
48
+ async def status():
49
+ return {"status": "ok", "engine": "Genie-TTS V2 Pro Plus"}
 
 
 
50
 
51
  if __name__ == "__main__":
52
+ # 启动前先预载角色模型
53
+ try:
54
+ genie_tts.load_character("Default", "/app", "zh")
55
+ except Exception as e:
56
+ print(f"Initial load error: {e}")
57
+
58
+ # 启动 FastAPI 服务 (HF Space 固定 7860)
59
  genie_tts.start_server(host="0.0.0.0", port=7860)