simler commited on
Commit
9ae828b
·
verified ·
1 Parent(s): 7cc7551

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -52
app.py CHANGED
@@ -2,80 +2,71 @@ import os
2
  import builtins
3
  import shutil
4
  import io
 
 
 
 
 
5
 
6
- # 🔴 最高优先级:在任何其他 import 之前直接劫持!
7
  builtins.input = lambda prompt="": "y"
8
  os.environ["GENIE_DATA_DIR"] = "/app/GenieData"
9
 
10
- # 🟢 关键:先下载,下载完之前不许往下走
11
- from huggingface_hub import snapshot_download
12
  if not os.path.exists("/app/GenieData/G2P"):
13
- print("📦 Booting: Downloading weights...")
14
- snapshot_download(
15
- repo_id="High-Logic/Genie",
16
- allow_patterns=["GenieData/*"],
17
- local_dir="/app",
18
- local_dir_use_symlinks=False
19
- )
20
- print("✅ Booting: Weights ready.")
21
-
22
- # 现在才允许导入其他库
23
- import uvicorn
24
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
25
- from fastapi.responses import StreamingResponse
26
- from fastapi.concurrency import run_in_threadpool
27
- import genie_tts
28
 
29
  app = FastAPI()
30
 
31
- # 启动后再预载入模型
32
- print("⚡ Warming up model...")
33
- try:
34
- genie_tts.load_character("Default", "/app", "zh")
35
- except:
36
- pass
37
 
38
- @app.post("/upload_and_set")
39
- async def upload_and_set(
40
- character_name: str = Form("Default"),
 
 
 
41
  prompt_text: str = Form(...),
 
42
  language: str = Form("zh"),
43
  file: UploadFile = File(...)
44
  ):
45
- char_name = character_name.capitalize()
46
- save_path = f"/app/uploaded_ref.wav"
47
- with open(save_path, "wb") as b:
48
- shutil.copyfileobj(file.file, b)
49
-
50
  try:
51
- await run_in_threadpool(genie_tts.set_reference_audio, char_name, save_path, prompt_text, language)
52
- return {"message": "Style updated"}
 
 
 
 
 
 
 
 
 
 
 
53
  except Exception as e:
54
- return {"error": str(e)}, 500
55
 
 
56
  @app.post("/tts")
57
- async def tts_endpoint(data: dict):
58
- char_name = data.get("character_name", "Default").capitalize()
59
  text = data.get("text", "")
 
60
 
61
  try:
62
- temp_out = "/app/temp_tts.wav"
63
- await run_in_threadpool(
64
- genie_tts.tts,
65
- character_name=char_name,
66
- text=text,
67
- save_path=temp_out,
68
- play=False
69
- )
70
 
71
- if os.path.exists(temp_out):
72
- with open(temp_out, "rb") as f:
73
- content = f.read()
74
- return StreamingResponse(io.BytesIO(content), media_type="audio/wav")
75
- raise Exception("Inference failed")
76
-
77
  except Exception as e:
78
- return {"detail": str(e)}, 400
79
 
80
  if __name__ == "__main__":
81
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  import builtins
3
  import shutil
4
  import io
5
+ import uvicorn
6
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
7
+ from fastapi.responses import StreamingResponse
8
+ import genie_tts
9
+ from huggingface_hub import snapshot_download
10
 
11
+ # --- 初期化屏蔽 ---
12
  builtins.input = lambda prompt="": "y"
13
  os.environ["GENIE_DATA_DIR"] = "/app/GenieData"
14
 
15
+ # 下载权重
 
16
  if not os.path.exists("/app/GenieData/G2P"):
17
+ snapshot_download(repo_id="High-Logic/Genie", allow_patterns=["GenieData/*"], local_dir="/app", local_dir_use_symlinks=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  app = FastAPI()
20
 
21
+ # 默认参考设定 (云端原始琴团长)
22
+ DEFAULT_REF_PATH = "/app/ref.wav"
23
+ DEFAULT_REF_TEXT = "琴是个称职的好团长。看到她认真工作的样子,就连我也忍不住想要多帮她一把。"
 
 
 
24
 
25
+ print("⚡ Starting Engine...")
26
+ genie_tts.load_character("Default", "/app", "zh")
27
+
28
+ # 🟢 端点 1:上传并立刻合成 (且不干扰默认值)
29
+ @app.post("/upload_and_tts")
30
+ async def upload_and_tts(
31
  prompt_text: str = Form(...),
32
+ text: str = Form(...),
33
  language: str = Form("zh"),
34
  file: UploadFile = File(...)
35
  ):
 
 
 
 
 
36
  try:
37
+ # 1. 保存临时上传的音频
38
+ temp_ref = f"/app/temp_upload_ref.wav"
39
+ with open(temp_ref, "wb") as buffer:
40
+ shutil.copyfileobj(file.file, buffer)
41
+
42
+ # 2. 设置本次参考
43
+ genie_tts.set_reference_audio("Default", temp_ref, prompt_text, language)
44
+
45
+ # 3. 推理合成
46
+ temp_out = "/app/temp_out.wav"
47
+ genie_tts.tts("Default", text, save_path=temp_out, play=False)
48
+
49
+ return StreamingResponse(open(temp_out, "rb"), media_type="audio/wav")
50
  except Exception as e:
51
+ raise HTTPException(status_code=500, detail=str(e))
52
 
53
+ # 🟢 端点 2:普通 TTS (强制回归原始琴团长)
54
  @app.post("/tts")
55
+ async def safe_tts(data: dict):
 
56
  text = data.get("text", "")
57
+ print(f"🔄 Safety Check: Forcing reset to original Qin voice...")
58
 
59
  try:
60
+ # 1. 强制重置回云端内置的 ref.wav
61
+ genie_tts.set_reference_audio("Default", DEFAULT_REF_PATH, DEFAULT_REF_TEXT, "zh")
62
+
63
+ # 2. 推理合成
64
+ temp_out = "/app/temp_std_out.wav"
65
+ genie_tts.tts("Default", text, save_path=temp_out, play=False)
 
 
66
 
67
+ return StreamingResponse(open(temp_out, "rb"), media_type="audio/wav")
 
 
 
 
 
68
  except Exception as e:
69
+ raise HTTPException(status_code=404, detail=str(e))
70
 
71
  if __name__ == "__main__":
72
  uvicorn.run(app, host="0.0.0.0", port=7860)