VoiceGate / scripts /workflow_client.py
YanTianlong's picture
Add TTS trim control and polish UI
1c552ae
Raw
History Blame Contribute Delete
4.98 kB
"""Small client for submitting VoiceGate workflows to a local ComfyUI API."""
from __future__ import annotations
import argparse
import copy
import json
import os
import time
import uuid
from pathlib import Path
from typing import Any
import requests
# Project root: this file lives one directory below it (in scripts/).
ROOT = Path(__file__).resolve().parents[1]
# Default workflow file used by load_workflow() and the --workflow option.
WORKFLOW_PATH = ROOT.joinpath("workflows", "voicegate_api.json")
# Separate text stream on file descriptor 1 that always encodes as UTF-8,
# independent of the interpreter's default stdout encoding. closefd=False
# keeps fd 1 itself open when this wrapper is closed or garbage-collected.
STDOUT = open(1, mode="w", encoding="utf-8", closefd=False)
def load_workflow(path: Path = WORKFLOW_PATH) -> dict[str, Any]:
    """Read and parse the ComfyUI API-format workflow stored at *path*.

    The file is decoded as UTF-8 and the parsed JSON value is returned
    unchanged.
    """
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)
def _node_inputs(workflow: dict[str, Any], node_id: str) -> dict[str, Any]:
    """Return ``workflow[node_id]["inputs"]``, naming the node if it is missing."""
    try:
        return workflow[node_id]["inputs"]
    except KeyError as err:
        raise KeyError(
            f"workflow node {node_id!r} (or its 'inputs' mapping) is missing; "
            "the workflow does not match the node ids this client patches"
        ) from err


def patch_voicegate_workflow(
    workflow: dict[str, Any],
    *,
    audio_filename: str,
    target_language: str,
    api_key: str | None,
    api_baseurl: str,
    llm_model: str,
    job_id: str | None = None,
    tts_trim_start: float = 0.0,
) -> dict[str, Any]:
    """Return a deep copy of *workflow* with the per-job inputs filled in.

    The input workflow is never modified. Nodes are addressed by hard-coded
    ids, so the workflow must use the same node ids as the bundled VoiceGate
    API workflow.

    Args:
        workflow: ComfyUI API-format workflow, keyed by node id.
        audio_filename: ComfyUI input filename, written to node "16" ``audio``.
        target_language: written to node "110" ``value``.
        api_key: LLM API key for node "105"; ``None`` is written as ``""``.
        api_baseurl: LLM base URL for node "105".
        llm_model: LLM model name for node "105".
        job_id: suffix for the ``filename_prefix`` of nodes "180" and "214";
            a falsy value is replaced by 12 random hex characters.
        tts_trim_start: written to node "268" ``start_index`` after being
            clamped to [0.0, 1.0] (NaN ends up as 0.0).

    Returns:
        The patched copy of the workflow.

    Raises:
        KeyError: if a required node, or its ``inputs`` mapping, is missing.
        ValueError, TypeError: if ``tts_trim_start`` is not float-convertible.
    """
    patched = copy.deepcopy(workflow)
    job = job_id or uuid.uuid4().hex[:12]
    # max() before min(): with NaN, max(0.0, nan) returns 0.0, so NaN is clamped too.
    trim_start = min(1.0, max(0.0, float(tts_trim_start)))

    _node_inputs(patched, "16")["audio"] = audio_filename

    llm_inputs = _node_inputs(patched, "105")
    llm_inputs["api_baseurl"] = api_baseurl
    llm_inputs["api_key"] = api_key or ""
    llm_inputs["model"] = llm_model

    _node_inputs(patched, "110")["value"] = target_language

    # A per-job suffix keeps output files from different runs apart.
    _node_inputs(patched, "180")["filename_prefix"] = f"audio/voicegate_{job}"
    _node_inputs(patched, "214")["filename_prefix"] = f"VoiceBridge/subtitle_{job}"

    # Fixed settings this client always enforces, overriding the workflow
    # file; "source" is only filled in when the workflow leaves it unset.
    model_inputs = _node_inputs(patched, "31")
    model_inputs.setdefault("source", "HuggingFace")
    model_inputs["attention"] = "sdpa"
    model_inputs["max_new_tokens"] = 256
    _node_inputs(patched, "206")["inference_steps"] = 4

    _node_inputs(patched, "268")["start_index"] = trim_start
    return patched
def upload_audio(
    server: str,
    audio_path: Path,
    *,
    overwrite: bool = True,
) -> str:
    """Send *audio_path* to the ComfyUI input folder and return its server-side name.

    The upload goes through ComfyUI's ``/upload/image`` route, which recent
    builds accept for input files of several media types. If that route
    ever changes, this is the only function that needs updating.

    Falls back to the local file name when the response carries no (or an
    empty) ``name``.
    """
    endpoint = f"{server}/upload/image"
    form_fields = {"overwrite": str(overwrite).lower(), "type": "input"}
    with audio_path.open("rb") as handle:
        upload = {"image": (audio_path.name, handle, "application/octet-stream")}
        reply = requests.post(endpoint, files=upload, data=form_fields, timeout=120)
    reply.raise_for_status()
    stored_name = reply.json().get("name")
    return stored_name if stored_name else audio_path.name
def submit_prompt(server: str, workflow: dict[str, Any]) -> str:
    """Queue *workflow* on the ComfyUI server and return its prompt id.

    A fresh random ``client_id`` is sent with every submission.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status. The
            status code and response body are included in the message, because
            the server's explanation of a rejected prompt (for example, which
            nodes failed validation) is only available in the body.
        KeyError: if a successful response has no ``prompt_id``.
    """
    response = requests.post(
        f"{server}/prompt",
        json={"prompt": workflow, "client_id": str(uuid.uuid4())},
        timeout=120,
    )
    # Same condition raise_for_status() checks, but keep the body in the error.
    if not response.ok:
        raise requests.HTTPError(
            f"{response.status_code} error submitting prompt to {server}: {response.text}",
            response=response,
        )
    payload = response.json()
    return payload["prompt_id"]
def wait_for_history(
    server: str,
    prompt_id: str,
    timeout: float = 1800,
    *,
    poll_interval: float = 2.0,
) -> dict[str, Any]:
    """Poll ``/history/<prompt_id>`` until the server lists *prompt_id*.

    Args:
        server: Base URL of the ComfyUI server.
        prompt_id: Id returned when the prompt was queued.
        timeout: Total number of seconds to keep polling.
        poll_interval: Seconds to wait between polls.

    Returns:
        The history entry stored under *prompt_id*.

    Raises:
        requests.HTTPError: if a poll gets a 4xx/5xx response.
        TimeoutError: if the entry does not appear before *timeout* elapses.
    """
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        response = requests.get(f"{server}/history/{prompt_id}", timeout=30)
        response.raise_for_status()
        payload = response.json()
        if prompt_id in payload:
            return payload[prompt_id]
        # Never sleep past the deadline on the last round.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    raise TimeoutError(f"Timed out waiting for prompt {prompt_id}")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Exits via ``parser.error`` (status 2) when a real submission is requested
    without ``--audio``; otherwise the workflow would be queued against the
    dummy "placeholder.mp3" input.
    """
    parser = argparse.ArgumentParser(
        description="Submit a VoiceGate workflow to a ComfyUI server and print the result.",
        epilog="The LLM API key is read from the DEEPSEEK_API_KEY environment variable.",
    )
    parser.add_argument(
        "--server",
        default=os.environ.get("COMFYUI_URL", "http://127.0.0.1:8188"),
        help="ComfyUI base URL (env COMFYUI_URL; default: %(default)s)",
    )
    parser.add_argument(
        "--workflow",
        type=Path,
        default=WORKFLOW_PATH,
        help="API-format workflow JSON to patch (default: %(default)s)",
    )
    parser.add_argument("--audio", type=Path, help="local audio file to upload and process")
    parser.add_argument("--target-language", default="English", help="default: %(default)s")
    parser.add_argument(
        "--tts-trim-start",
        type=float,
        default=0.0,
        help="value for the TTS trim start_index, clamped to [0, 1] (default: %(default)s)",
    )
    parser.add_argument(
        "--api-baseurl",
        default=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
        help="LLM API base URL (env DEEPSEEK_BASE_URL; default: %(default)s)",
    )
    parser.add_argument(
        "--llm-model",
        default=os.environ.get("DEEPSEEK_MODEL", "deepseek-v4-flash"),
        help="LLM model name (env DEEPSEEK_MODEL; default: %(default)s)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="print the patched workflow instead of uploading and submitting it",
    )
    args = parser.parse_args(argv)
    if args.audio is None and not args.dry_run:
        parser.error("--audio is required unless --dry-run is given")
    return args
def _write_json(obj: Any) -> None:
    """Pretty-print *obj* as UTF-8 JSON on STDOUT and flush it immediately."""
    json.dump(obj, STDOUT, ensure_ascii=False, indent=2)
    STDOUT.write("\n")
    # STDOUT is its own buffered wrapper over fd 1 (separate from sys.stdout);
    # flush explicitly instead of relying on finalization at interpreter exit.
    STDOUT.flush()


def main() -> None:
    """Command-line entry point.

    With ``--dry-run`` the patched workflow is printed as JSON and nothing is
    sent to the server. Otherwise the audio (if given) is uploaded, the
    patched workflow is queued, and the resulting history entry is printed
    as JSON once the server reports it.
    """
    args = parse_args()
    # Tolerate a trailing slash in --server / COMFYUI_URL so request URLs
    # do not end up with "//" paths.
    server = args.server.rstrip("/")
    workflow = load_workflow(args.workflow)

    audio_filename = args.audio.name if args.audio else "placeholder.mp3"
    if args.audio and not args.dry_run:
        audio_filename = upload_audio(server, args.audio)

    patched = patch_voicegate_workflow(
        workflow,
        audio_filename=audio_filename,
        target_language=args.target_language,
        api_key=os.environ.get("DEEPSEEK_API_KEY"),
        api_baseurl=args.api_baseurl,
        llm_model=args.llm_model,
        tts_trim_start=args.tts_trim_start,
    )
    if args.dry_run:
        _write_json(patched)
        return

    prompt_id = submit_prompt(server, patched)
    # Flush now: polling can take a long time and stdout may be block-buffered
    # when piped, which would hide this status line until exit.
    print(f"Submitted prompt {prompt_id}", flush=True)
    history = wait_for_history(server, prompt_id)
    _write_json(history)


if __name__ == "__main__":
    main()