# gpt-sovits / ui.py
# Author: huanx. Commit 17609fa (verified): "Upload ui.py with huggingface_hub".
import glob
import html
import os
import shutil
import sys
import uuid
from pathlib import Path
from typing import Any
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import HTMLResponse, JSONResponse
import uvicorn
# Path to the TTS pipeline config handed to api_v2; override with the
# SPACE_TTS_CONFIG environment variable (defaults to the CPU config).
SPACE_TTS_CONFIG = os.getenv("SPACE_TTS_CONFIG", "GPT_SoVITS/configs/tts_infer_cpu.yaml")
# Listening port kept as a string: it is placed into sys.argv below and is
# converted with int() when this module runs uvicorn directly.
SPACE_PORT = os.getenv("PORT", "7860")
# Replace the process arguments with an api_v2-style command line
# (-a address, -p port, -c config) *before* importing api_v2.
# NOTE(review): this assumes api_v2 reads sys.argv at import time to build
# tts_config / tts_pipeline -- confirm in api_v2.py. The assignment must stay
# above the import, and it discards any real command-line arguments.
sys.argv = [
    "api_v2.py",
    "-a",
    "0.0.0.0",
    "-p",
    SPACE_PORT,
    "-c",
    SPACE_TTS_CONFIG,
]
import api_v2  # noqa: E402 -- must run after the sys.argv override above
# The wrapper app; the routes below are registered on it and delegate
# synthesis to api_v2.tts_handle.
app = FastAPI(title="GPT-SoVITS Space")
# Directories searched recursively for *.ckpt / *.pth files by get_models.
PRETRAINED_DIR = "/app/GPT_SoVITS/pretrained_models"
CUSTOM_DIR = "/data/models"
# Reference clips uploaded through /ui/tts are written here; the directory is
# created at import time.
UPLOAD_DIR = Path("/tmp/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# Fallback reference clip used when a request supplies no reference audio.
DEFAULT_REF_AUDIO = Path("/data/models/zh_vo_MAIN_YHX_2_12.wav")
# In this file, only reported (path + existence) by the /api/health endpoint.
FAST_LANGDETECT_MODEL = Path("/app/GPT_SoVITS/pretrained_models/fast_langdetect/lid.176.bin")
# Prompt text paired with DEFAULT_REF_AUDIO (presumably its transcript:
# "Is that so? Sorry, I don't remember.").
DEFAULT_REF_TEXT = "是吗,抱歉哦,我不记得了。"
# Default language codes for the reference prompt and for the text to synthesize.
DEFAULT_PROMPT_LANG = "zh"
DEFAULT_TEXT_LANG = "zh"
# The single character/emotion pair advertised by /character_list and used as
# the default by the GSVI-compatible /tts endpoints.
DEFAULT_CHARACTER = "aimisi"
DEFAULT_EMOTION = "default"
def get_models():
    """Collect every ``*.ckpt`` / ``*.pth`` file under the model directories.

    CUSTOM_DIR and PRETRAINED_DIR are searched recursively, and a directory
    that does not exist is skipped. Each hit is reported as a dict holding
    the file's base name (``name``) and full path (``path``). The result is
    ordered by base name.
    """
    found = []
    search_roots = (CUSTOM_DIR, PRETRAINED_DIR)
    for root in filter(os.path.exists, search_roots):
        for suffix_glob in ("*.ckpt", "*.pth"):
            pattern = os.path.join(root, "**", suffix_glob)
            found.extend(
                {"name": os.path.basename(hit), "path": hit}
                for hit in glob.glob(pattern, recursive=True)
            )
    return sorted(found, key=lambda entry: entry["name"])
def get_languages():
    """Return the lowercase, de-duplicated, sorted language codes for the UI.

    Codes come from ``api_v2.tts_config.languages``. When that attribute is
    missing or empty, the fallback is the default text and prompt languages
    plus ``"auto"``.
    """
    configured = list(getattr(api_v2.tts_config, "languages", []))
    candidates = configured or [DEFAULT_TEXT_LANG, DEFAULT_PROMPT_LANG, "auto"]
    unique_codes = set(code.lower() for code in candidates)
    return sorted(unique_codes)
def current_model_paths():
    """Report the GPT (t2s) and SoVITS weight paths the TTS pipeline is using.

    Reads ``api_v2.tts_pipeline.configs``. If the pipeline, its configs, or an
    attribute is missing (or the attribute is ``None``), the value is ``""``.

    Returns:
        ``{"gpt": str, "sovits": str}``. Values are always ``str`` because
        callers such as ``index`` pass them to ``html.escape``, which raises
        on ``None`` or ``Path`` objects.
    """
    pipeline = getattr(api_v2, "tts_pipeline", None)
    configs = getattr(pipeline, "configs", None)

    def _as_text(attribute: str) -> str:
        # getattr on None (no pipeline/configs) simply yields the default.
        value = getattr(configs, attribute, None)
        return "" if value is None else str(value)

    return {
        "gpt": _as_text("t2s_weights_path"),
        "sovits": _as_text("vits_weights_path"),
    }
def language_options(selected: str):
    """Render one ``<option>`` tag per language from ``get_languages``.

    The option whose code equals *selected* carries the ``selected``
    attribute. Codes are HTML-escaped before going into both the value and
    the visible label. The tags are concatenated with no separator.
    """
    def render(code):
        escaped = html.escape(code)
        marker = " selected" if code == selected else ""
        return f'<option value="{escaped}"{marker}>{escaped}</option>'

    return "".join(render(code) for code in get_languages())
def normalize_language(language: str | None) -> str:
if not language:
return "auto"
return language.lower()
def _coerce_bool(value: Any, default: bool) -> bool:
    """Interpret a loosely typed flag coming from query or JSON input.

    ``None`` yields *default*. Strings are parsed case-insensitively, so
    ``"false"``, ``"0"``, ``"no"``, ``"off"`` and ``""`` become ``False``;
    a plain ``bool("false")`` would be ``True``. Any other value falls back
    to Python truthiness.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() not in {"", "0", "false", "no", "off"}
    return bool(value)


def build_tts_request(
    text: str,
    text_lang: str | None = None,
    prompt_lang: str | None = None,
    ref_text: str | None = None,
    ref_audio_path: str | Path | None = None,
    media_type: str | None = None,
    speed: float | None = None,
    top_k: int | None = None,
    top_p: float | None = None,
    temperature: float | None = None,
    batch_size: int | None = None,
    stream: bool | None = None,
    text_split_method: str | None = None,
    batch_threshold: float | None = None,
    split_bucket: bool | None = None,
    speed_factor: float | None = None,
    fragment_interval: float | None = None,
    seed: int | None = None,
    parallel_infer: bool | None = None,
    repetition_penalty: float | None = None,
) -> dict[str, Any]:
    """Assemble the request dict passed to ``api_v2.tts_handle``.

    Args:
        text: Text to synthesize; surrounding whitespace is stripped.
        text_lang: Target language, normalized with ``normalize_language``
            (``None`` or empty becomes ``"auto"``).
        prompt_lang: Reference-clip language; defaults to ``DEFAULT_PROMPT_LANG``.
        ref_text: Reference prompt text; defaults to ``DEFAULT_REF_TEXT``.
        ref_audio_path: Reference clip; defaults to ``DEFAULT_REF_AUDIO``.
        media_type: Output format, lowercased; defaults to ``"wav"``.
        speed: GSVI name for the speed factor; takes precedence over
            ``speed_factor`` when both are given.
        speed_factor: api_v2 name for the speed factor.
        stream, split_bucket, parallel_infer: Flags defaulting to
            ``False``/``True``/``True``. Strings such as ``"false"`` or
            ``"0"`` are parsed as ``False``.
        top_k, top_p, temperature, batch_size, fragment_interval: Included
            only when not ``None``.
        text_split_method, batch_threshold, seed, repetition_penalty: Fall
            back to ``"cut5"``, ``0.75``, ``-1`` and ``1.35``.

    Returns:
        A new dict. Numeric values are cast with ``int``/``float``.

    Raises:
        ValueError, TypeError: if a numeric argument cannot be converted.
    """
    request: dict[str, Any] = {
        "text": text.strip(),
        "text_lang": normalize_language(text_lang),
        "ref_audio_path": str(ref_audio_path or DEFAULT_REF_AUDIO),
        "aux_ref_audio_paths": [],
        "prompt_text": (ref_text or DEFAULT_REF_TEXT).strip(),
        "prompt_lang": normalize_language(prompt_lang or DEFAULT_PROMPT_LANG),
        "media_type": (media_type or "wav").lower(),
        "streaming_mode": _coerce_bool(stream, False),
        "text_split_method": text_split_method or "cut5",
        "batch_threshold": float(batch_threshold) if batch_threshold is not None else 0.75,
        "split_bucket": _coerce_bool(split_bucket, True),
        "seed": int(seed) if seed is not None else -1,
        "parallel_infer": _coerce_bool(parallel_infer, True),
        "repetition_penalty": float(repetition_penalty) if repetition_penalty is not None else 1.35,
    }
    # ``speed`` (GSVI spelling) wins over ``speed_factor`` (api_v2 spelling).
    speed_value = speed if speed is not None else speed_factor
    if speed_value is not None:
        request["speed_factor"] = float(speed_value)
    # Optional tuning keys: only forwarded when the caller supplied them, so
    # api_v2 applies its own defaults otherwise. Order matches the original
    # insertion order of the dict.
    optional_fields = (
        ("top_k", top_k, int),
        ("top_p", top_p, float),
        ("temperature", temperature, float),
        ("batch_size", batch_size, int),
        ("fragment_interval", fragment_interval, float),
    )
    for key, value, cast in optional_fields:
        if value is not None:
            request[key] = cast(value)
    return request
@app.get("/", response_class=HTMLResponse)
def index():
    """Render the single-page HTML UI served at ``/``.

    The page shows:
    - the active GPT and SoVITS weight paths;
    - whether ``DEFAULT_REF_AUDIO`` exists;
    - how many checkpoint files ``get_models`` found;
    - a form that POSTs to ``/ui/tts`` and plays the returned audio.

    Client-side behavior:
    - The submit button is disabled while a request is in flight.
    - The previous clip's object URL is revoked before each new request, so
      audio blobs do not accumulate in the browser.
    - Non-JSON error bodies fall back to an HTTP-status message.
    """
    models = get_models()
    model_paths = current_model_paths()
    default_ref_status = "available" if DEFAULT_REF_AUDIO.exists() else "missing"
    # str(... or "") guards html.escape, which raises on None/Path values.
    gpt_path = html.escape(str(model_paths["gpt"] or ""))
    sovits_path = html.escape(str(model_paths["sovits"] or ""))
    return f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>GPT-SoVITS Space</title>
<style>
:root {{
--bg: #0d1117;
--panel: #161b22;
--panel-2: #1f2733;
--text: #e6edf3;
--muted: #9fb0c0;
--accent: #ff8a3d;
--accent-2: #ffd166;
--ok: #6fd3a3;
--border: rgba(255,255,255,0.08);
}}
* {{ box-sizing: border-box; }}
body {{
margin: 0;
font-family: "Segoe UI", "PingFang SC", "Noto Sans CJK SC", sans-serif;
background:
radial-gradient(circle at top right, rgba(255,138,61,0.18), transparent 28%),
radial-gradient(circle at left top, rgba(255,209,102,0.12), transparent 24%),
var(--bg);
color: var(--text);
min-height: 100vh;
}}
main {{
max-width: 860px;
margin: 0 auto;
padding: 32px 20px 48px;
}}
h1 {{
margin: 0 0 8px;
font-size: clamp(28px, 5vw, 44px);
}}
p {{
color: var(--muted);
line-height: 1.6;
}}
.grid {{
display: grid;
gap: 18px;
grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
margin: 20px 0 26px;
}}
.card {{
background: linear-gradient(180deg, rgba(255,255,255,0.02), rgba(255,255,255,0.01));
border: 1px solid var(--border);
border-radius: 18px;
padding: 18px;
backdrop-filter: blur(8px);
}}
.eyebrow {{
color: var(--accent-2);
text-transform: uppercase;
letter-spacing: 0.12em;
font-size: 12px;
margin-bottom: 8px;
}}
code {{
word-break: break-all;
color: var(--accent-2);
}}
form {{
background: var(--panel);
border: 1px solid var(--border);
border-radius: 22px;
padding: 22px;
}}
label {{
display: block;
margin-top: 14px;
margin-bottom: 8px;
font-size: 14px;
color: var(--muted);
}}
input, textarea, select, button {{
width: 100%;
border-radius: 14px;
border: 1px solid var(--border);
background: var(--panel-2);
color: var(--text);
padding: 12px 14px;
font-size: 15px;
}}
textarea {{
min-height: 132px;
resize: vertical;
}}
.row {{
display: grid;
gap: 14px;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
}}
button {{
margin-top: 18px;
background: linear-gradient(135deg, var(--accent), var(--accent-2));
color: #111;
font-weight: 700;
cursor: pointer;
border: none;
}}
button:disabled {{
opacity: 0.6;
cursor: progress;
}}
#status {{
margin-top: 16px;
color: var(--ok);
min-height: 24px;
}}
audio {{
width: 100%;
margin-top: 16px;
display: none;
}}
.links {{
display: flex;
flex-wrap: wrap;
gap: 12px;
margin-top: 18px;
}}
.links a {{
color: var(--accent-2);
text-decoration: none;
}}
</style>
</head>
<body>
<main>
<div class="eyebrow">Hugging Face Space</div>
<h1>GPT-SoVITS CPU Inference</h1>
<p>
The Space now boots through a small FastAPI wrapper instead of exposing a bare API root.
If you do not upload a reference clip, it falls back to the built-in Aimisi sample at
<code>{html.escape(str(DEFAULT_REF_AUDIO))}</code>.
</p>
<section class="grid">
<div class="card">
<div class="eyebrow">Current GPT</div>
<code>{gpt_path}</code>
</div>
<div class="card">
<div class="eyebrow">Current SoVITS</div>
<code>{sovits_path}</code>
</div>
<div class="card">
<div class="eyebrow">Reference Sample</div>
<p>Status: {default_ref_status}<br>Loaded models: {len(models)}</p>
</div>
</section>
<form id="tts-form" enctype="multipart/form-data">
<label for="audio">Reference audio upload (optional)</label>
<input id="audio" type="file" name="audio" accept="audio/*">
<label for="ref_text">Reference text</label>
<textarea id="ref_text" name="ref_text">{html.escape(DEFAULT_REF_TEXT)}</textarea>
<div class="row">
<div>
<label for="prompt_lang">Reference language</label>
<select id="prompt_lang" name="prompt_lang">{language_options(DEFAULT_PROMPT_LANG)}</select>
</div>
<div>
<label for="text_lang">Target language</label>
<select id="text_lang" name="text_lang">{language_options(DEFAULT_TEXT_LANG)}</select>
</div>
</div>
<label for="text">Text to synthesize</label>
<textarea id="text" name="text" placeholder="输入你要合成的文本,或者直接写日语内容。"></textarea>
<button type="submit">Generate Audio</button>
<div id="status">Ready</div>
<audio id="player" controls></audio>
</form>
<div class="links">
<a href="/docs" target="_blank" rel="noreferrer">Open FastAPI docs</a>
<a href="/api/health" target="_blank" rel="noreferrer">Health JSON</a>
<a href="/api/models" target="_blank" rel="noreferrer">Model JSON</a>
<a href="/character_list" target="_blank" rel="noreferrer">GSVI character_list</a>
</div>
</main>
<script>
const form = document.getElementById("tts-form");
const status = document.getElementById("status");
const player = document.getElementById("player");
const submitButton = form.querySelector('button[type="submit"]');
// Object URL of the clip currently in the player; revoked before the next
// request so earlier audio blobs can be released.
let currentAudioUrl = null;
form.addEventListener("submit", async (event) => {{
event.preventDefault();
const data = new FormData(form);
status.textContent = "Generating...";
submitButton.disabled = true;
player.style.display = "none";
player.removeAttribute("src");
if (currentAudioUrl) {{
URL.revokeObjectURL(currentAudioUrl);
currentAudioUrl = null;
}}
try {{
const response = await fetch("/ui/tts", {{
method: "POST",
body: data,
}});
if (!response.ok) {{
// Error bodies are normally JSON, but a proxy or crash page may not be.
let message = "Generation failed";
try {{
const payload = await response.json();
message = payload.message || payload.Exception || message;
}} catch (parseError) {{
message = "Generation failed (HTTP " + response.status + ")";
}}
status.textContent = message;
return;
}}
const blob = await response.blob();
currentAudioUrl = URL.createObjectURL(blob);
player.src = currentAudioUrl;
player.style.display = "block";
status.textContent = "Done";
}} catch (error) {{
status.textContent = error.message || "Request failed";
}} finally {{
submitButton.disabled = false;
}}
}});
</script>
</body>
</html>"""
@app.post("/ui/tts")
async def ui_tts(
    audio: UploadFile | None = File(None),
    ref_text: str = Form(DEFAULT_REF_TEXT),
    text: str = Form(...),
    text_lang: str = Form(DEFAULT_TEXT_LANG),
    prompt_lang: str = Form(DEFAULT_PROMPT_LANG),
):
    """Synthesize speech for the HTML form served at ``/``.

    An uploaded reference clip is saved under ``UPLOAD_DIR`` with a random
    UUID name. It keeps the upload's suffix, or ``.wav`` if it has none, and
    is deleted again once ``api_v2.tts_handle`` has returned. Without an
    upload, ``DEFAULT_REF_AUDIO`` is used if it exists.

    Returns:
        The ``api_v2.tts_handle`` response for a non-streaming WAV request,
        or a 400 JSON response when ``text`` is blank or no reference audio
        is available.
    """
    if not text.strip():
        return JSONResponse(status_code=400, content={"message": "text is required"})

    uploaded_path: Path | None = None
    if audio and audio.filename:
        # Only the suffix of the client-supplied name is used, so the saved
        # file always sits directly inside UPLOAD_DIR.
        suffix = Path(audio.filename).suffix or ".wav"
        uploaded_path = UPLOAD_DIR / f"{uuid.uuid4()}{suffix}"
        ref_audio_path = uploaded_path
    elif DEFAULT_REF_AUDIO.exists():
        ref_audio_path = DEFAULT_REF_AUDIO
    else:
        return JSONResponse(status_code=400, content={"message": "reference audio is required"})

    try:
        if uploaded_path is not None:
            with uploaded_path.open("wb") as handle:
                shutil.copyfileobj(audio.file, handle)
        req = build_tts_request(
            text=text,
            text_lang=text_lang,
            prompt_lang=prompt_lang,
            ref_text=ref_text,
            ref_audio_path=ref_audio_path,
            media_type="wav",
            stream=False,
        )
        return await api_v2.tts_handle(req)
    finally:
        # Remove the per-request upload (including a partially written one)
        # so UPLOAD_DIR does not grow without bound. stream=False is forced
        # above, so the clip is expected to be fully consumed by the time
        # tts_handle returns. This is an assumption about api_v2's
        # non-streaming path; revisit if streaming is ever enabled here.
        if uploaded_path is not None:
            uploaded_path.unlink(missing_ok=True)
@app.get("/character_list")
def character_list():
    """Advertise the available characters and their emotions (GSVI format).

    Only ``DEFAULT_CHARACTER`` is exposed, with the single emotion
    ``DEFAULT_EMOTION``.
    """
    emotions = [DEFAULT_EMOTION]
    return {DEFAULT_CHARACTER: emotions}
@app.get("/tts")
async def gsvi_tts_get(
    text: str = "",
    character: str = DEFAULT_CHARACTER,
    emotion: str = DEFAULT_EMOTION,
    text_language: str = "auto",
    format: str = "wav",
    top_k: int | None = None,
    top_p: float | None = None,
    batch_size: int | None = None,
    speed: float | None = None,
    temperature: float | None = None,
    stream: bool = False,
    text_split_method: str | None = None,
    batch_threshold: float | None = None,
    split_bucket: bool | None = None,
    speed_factor: float | None = None,
    fragment_interval: float | None = None,
    seed: int | None = None,
    parallel_infer: bool | None = None,
    repetition_penalty: float | None = None,
):
    """GSVI-compatible GET synthesis endpoint.

    Query parameters are forwarded to ``build_tts_request``: ``text_language``
    becomes ``text_lang``, ``format`` becomes ``media_type``, and the tuning
    parameters keep their names. ``character`` and ``emotion`` are attached to
    the request unchanged. Blank ``text`` is rejected with a 400 response.
    """
    if text.strip() == "":
        return JSONResponse(status_code=400, content={"message": "text is required"})

    tuning = dict(
        speed=speed,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        batch_size=batch_size,
        stream=stream,
        text_split_method=text_split_method,
        batch_threshold=batch_threshold,
        split_bucket=split_bucket,
        speed_factor=speed_factor,
        fragment_interval=fragment_interval,
        seed=seed,
        parallel_infer=parallel_infer,
        repetition_penalty=repetition_penalty,
    )
    req = build_tts_request(text=text, text_lang=text_language, media_type=format, **tuning)
    req.update(character=character, emotion=emotion)
    return await api_v2.tts_handle(req)
@app.post("/tts")
async def gsvi_tts_post(request: Request):
    """GSVI-compatible JSON synthesis endpoint.

    Expects a JSON object body, and ``text`` is required. ``text_language``
    (default ``"auto"``), ``format`` (default ``"wav"``) and the optional
    tuning keys are forwarded to ``build_tts_request``. ``character`` and
    ``emotion`` are attached to the request unchanged.

    Returns:
        The ``api_v2.tts_handle`` response. A 400 JSON response is returned
        instead when the body is not a JSON object, ``text`` is missing or
        blank, or a parameter has an unusable type or value.
    """
    try:
        payload = await request.json()
    except (ValueError, RecursionError):
        # JSONDecodeError and UnicodeDecodeError subclass ValueError; absurdly
        # deep nesting surfaces as RecursionError.
        return JSONResponse(status_code=400, content={"message": "invalid json body"})
    if not isinstance(payload, dict):
        # A bare list/string/number would crash on payload.get below.
        return JSONResponse(status_code=400, content={"message": "invalid json body"})

    raw_text = payload.get("text")
    # A JSON null must not be synthesized as the literal word "None".
    text = "" if raw_text is None else str(raw_text).strip()
    if not text:
        return JSONResponse(status_code=400, content={"message": "text is required"})

    try:
        req = build_tts_request(
            text=text,
            text_lang=payload.get("text_language", "auto"),
            media_type=payload.get("format", "wav"),
            speed=payload.get("speed"),
            top_k=payload.get("top_k"),
            top_p=payload.get("top_p"),
            temperature=payload.get("temperature"),
            batch_size=payload.get("batch_size"),
            stream=payload.get("stream", False),
            text_split_method=payload.get("text_split_method"),
            batch_threshold=payload.get("batch_threshold"),
            split_bucket=payload.get("split_bucket"),
            speed_factor=payload.get("speed_factor"),
            fragment_interval=payload.get("fragment_interval"),
            seed=payload.get("seed"),
            parallel_infer=payload.get("parallel_infer"),
            repetition_penalty=payload.get("repetition_penalty"),
        )
    except (AttributeError, TypeError, ValueError) as exc:
        # Wrongly typed client values (e.g. "top_k": "abc", or a non-string
        # "format" hitting .lower()) would otherwise surface as 500 errors.
        return JSONResponse(status_code=400, content={"message": f"invalid parameter: {exc}"})
    req["character"] = payload.get("character", DEFAULT_CHARACTER)
    req["emotion"] = payload.get("emotion", DEFAULT_EMOTION)
    return await api_v2.tts_handle(req)
@app.get("/api/models")
def api_models():
    """Return the active weight paths and every checkpoint file found on disk."""
    current = current_model_paths()
    available = get_models()
    return {"current": current, "available": available}
@app.get("/api/health")
def health():
    """Report a status snapshot of the Space as JSON.

    The snapshot includes:
    - the number of checkpoint files ``get_models`` finds;
    - the paths of the default reference clip and the fast_langdetect model,
      each with a flag saying whether the file exists;
    - the active GPT and SoVITS weight paths.
    """
    paths = current_model_paths()
    model_count = len(get_models())
    ref_audio = DEFAULT_REF_AUDIO
    langdetect = FAST_LANGDETECT_MODEL
    return dict(
        status="ok",
        loaded_models=model_count,
        default_reference_audio=str(ref_audio),
        default_reference_audio_exists=ref_audio.exists(),
        fast_langdetect_model=str(langdetect),
        fast_langdetect_model_exists=langdetect.exists(),
        current_gpt=paths["gpt"],
        current_sovits=paths["sovits"],
    )
# Entry point: serve this wrapper app directly when run as a script.
if __name__ == "__main__":
    listen_port = int(SPACE_PORT)  # SPACE_PORT is a string read from $PORT
    uvicorn.run(app, host="0.0.0.0", port=listen_port)