File size: 4,975 Bytes
683b147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c552ae
683b147
 
 
1c552ae
683b147
 
 
 
 
 
 
 
d5eb1f4
dcb580e
90f8205
 
1c552ae
683b147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c552ae
683b147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c552ae
683b147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Small client for submitting VoiceGate workflows to a local ComfyUI API."""

from __future__ import annotations

import argparse
import copy
import json
import os
import time
import uuid
from pathlib import Path
from typing import Any

import requests


# Grandparent of this file's resolved path, i.e. the parent of the directory
# that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Default API-format workflow consumed by ``load_workflow``.
WORKFLOW_PATH = ROOT.joinpath("workflows", "voicegate_api.json")
# Explicit UTF-8 text stream over file descriptor 1, independent of whatever
# encoding sys.stdout uses. closefd=False leaves fd 1 open if this wrapper is
# closed. NOTE: it is buffered separately from sys.stdout, so interleaved
# writes to both need explicit flushes to keep their order.
STDOUT = open(1, mode="w", encoding="utf-8", closefd=False)


def load_workflow(path: Path = WORKFLOW_PATH) -> dict[str, Any]:
    """Read and decode a workflow JSON file.

    Args:
        path: JSON file to read as UTF-8; defaults to ``WORKFLOW_PATH``.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: if the file cannot be opened or read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    raw_text = path.read_text(encoding="utf-8")
    return json.loads(raw_text)


def patch_voicegate_workflow(
    workflow: dict[str, Any],
    *,
    audio_filename: str,
    target_language: str,
    api_key: str | None,
    api_baseurl: str,
    llm_model: str,
    job_id: str | None = None,
    tts_trim_start: float = 0.0,
    attention: str = "sdpa",
    max_new_tokens: int = 256,
    inference_steps: int = 4,
) -> dict[str, Any]:
    """Return a deep copy of *workflow* with the per-job VoiceGate inputs filled in.

    The input mapping is never mutated. Each patched node is looked up by its
    id in *workflow* and must carry an ``"inputs"`` dict.

    Args:
        workflow: Workflow mapping node id -> node dict.
        audio_filename: Written to node "16" input ``audio``.
        target_language: Written to node "110" input ``value``.
        api_key: Written to node "105" ``api_key``; ``None`` becomes ``""``.
        api_baseurl: Written to node "105" ``api_baseurl``.
        llm_model: Written to node "105" ``model``.
        job_id: Suffix for the output ``filename_prefix`` of nodes "180" and
            "214". When falsy, a random 12-hex-character id is generated.
        tts_trim_start: Converted with ``float()``, clamped to [0.0, 1.0] and
            written to node "268" ``start_index``.
        attention: Written to node "31" ``attention`` (default ``"sdpa"``).
        max_new_tokens: Written to node "31" ``max_new_tokens`` (default 256).
        inference_steps: Written to node "206" ``inference_steps`` (default 4).

    Returns:
        The patched deep copy.

    Raises:
        KeyError: if a required node id or its ``"inputs"`` entry is missing.
        ValueError / TypeError: propagated from ``float(tts_trim_start)``.
    """
    patched = copy.deepcopy(workflow)
    job = job_id or uuid.uuid4().hex[:12]
    # Clamp into [0.0, 1.0]. A NaN collapses to 0.0 because max() keeps its
    # first argument when the comparison is False.
    trim_start = min(1.0, max(0.0, float(tts_trim_start)))

    def inputs(node_id: str) -> dict[str, Any]:
        # Name the node in the error so a workflow/node-id mismatch is obvious.
        try:
            return patched[node_id]["inputs"]
        except KeyError:
            raise KeyError(
                f"workflow node {node_id!r} or its 'inputs' entry is missing"
            ) from None

    inputs("16")["audio"] = audio_filename
    llm_inputs = inputs("105")
    llm_inputs["api_baseurl"] = api_baseurl
    llm_inputs["api_key"] = api_key or ""
    llm_inputs["model"] = llm_model
    inputs("110")["value"] = target_language
    inputs("180")["filename_prefix"] = f"audio/voicegate_{job}"
    inputs("214")["filename_prefix"] = f"VoiceBridge/subtitle_{job}"
    node31 = inputs("31")
    # Only fill in a source when the workflow does not already pin one.
    node31.setdefault("source", "HuggingFace")
    node31["attention"] = attention
    node31["max_new_tokens"] = max_new_tokens
    inputs("206")["inference_steps"] = inference_steps
    inputs("268")["start_index"] = trim_start
    return patched


def upload_audio(
    server: str,
    audio_path: Path,
    *,
    overwrite: bool = True,
) -> str:
    """Upload audio to ComfyUI and return the ComfyUI input filename.

    Recent ComfyUI builds accept `/upload/image` for input file uploads across
    several media types. If this changes, this function is the only place that
    should need adjustment.

    Args:
        server: Base URL of the ComfyUI server; a trailing ``/`` is tolerated.
        audio_path: Local file to upload (sent as ``application/octet-stream``).
        overwrite: Sent as the ``overwrite`` form field ("true"/"false").

    Returns:
        The ``name`` reported by the server, or ``audio_path.name`` if the
        response has no (or an empty) ``name``.

    Raises:
        OSError: if *audio_path* cannot be opened.
        requests.HTTPError: on a non-2xx response.
    """
    # Strip a trailing slash so "http://host:8188/" does not yield "//upload/image",
    # which would not match the server route.
    url = f"{server.rstrip('/')}/upload/image"
    with audio_path.open("rb") as file:
        files = {"image": (audio_path.name, file, "application/octet-stream")}
        data = {"overwrite": str(overwrite).lower(), "type": "input"}
        response = requests.post(url, files=files, data=data, timeout=120)
    response.raise_for_status()
    payload = response.json()
    return payload.get("name") or audio_path.name


def submit_prompt(server: str, workflow: dict[str, Any]) -> str:
    """Queue *workflow* on the server's ``/prompt`` endpoint and return its prompt id.

    Args:
        server: Base URL of the ComfyUI server; a trailing ``/`` is tolerated.
        workflow: Workflow sent as the ``prompt`` field, together with a fresh
            random ``client_id``.

    Returns:
        The ``prompt_id`` from the JSON response.

    Raises:
        requests.HTTPError: on a 4xx/5xx response. The message includes the
            response body, which usually carries the server's explanation of
            why the prompt was rejected.
        KeyError: if the JSON response has no ``prompt_id``.
    """
    response = requests.post(
        f"{server.rstrip('/')}/prompt",
        json={"prompt": workflow, "client_id": str(uuid.uuid4())},
        timeout=120,
    )
    if not response.ok:
        # raise_for_status() would discard the body, which is where the reason
        # for a rejected workflow is reported, so include it in the error.
        raise requests.HTTPError(
            f"Prompt submission failed with HTTP {response.status_code}: "
            f"{response.text[:4000]}",
            response=response,
        )
    payload = response.json()
    return payload["prompt_id"]


def wait_for_history(
    server: str,
    prompt_id: str,
    timeout: float = 1800,
    *,
    poll_interval: float = 2.0,
) -> dict[str, Any]:
    """Poll ``/history/{prompt_id}`` until the prompt appears, then return its entry.

    Args:
        server: Base URL of the ComfyUI server; a trailing ``/`` is tolerated.
        prompt_id: Id returned by ``submit_prompt``.
        timeout: Overall time budget in seconds.
        poll_interval: Seconds to wait between polls (capped at the time left).

    Returns:
        The value stored under *prompt_id* in the history response.

    Raises:
        TimeoutError: if *prompt_id* has not appeared within *timeout* seconds.
        requests.HTTPError: on a non-2xx polling response.
    """
    url = f"{server.rstrip('/')}/history/{prompt_id}"
    # monotonic() cannot jump when the system clock is adjusted (NTP step,
    # manual change), so the deadline is neither shortened nor stretched.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        payload = response.json()
        if prompt_id in payload:
            return payload[prompt_id]
        # Never sleep past the deadline (and never a negative duration).
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    raise TimeoutError(f"Timed out waiting for prompt {prompt_id}")


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the VoiceGate ComfyUI client.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous zero-argument behavior.

    Returns:
        The parsed namespace. ``--server``, ``--api-baseurl`` and
        ``--llm-model`` default to the ``COMFYUI_URL``, ``DEEPSEEK_BASE_URL``
        and ``DEEPSEEK_MODEL`` environment variables when set.
    """
    parser = argparse.ArgumentParser(
        description="Submit the VoiceGate workflow to a ComfyUI server.",
        epilog="The LLM API key is read from the DEEPSEEK_API_KEY environment variable.",
    )
    parser.add_argument(
        "--server",
        default=os.environ.get("COMFYUI_URL", "http://127.0.0.1:8188"),
        help="ComfyUI base URL (env: COMFYUI_URL).",
    )
    parser.add_argument(
        "--workflow",
        type=Path,
        default=WORKFLOW_PATH,
        help="Workflow JSON file to patch and submit.",
    )
    parser.add_argument(
        "--audio",
        type=Path,
        help="Local audio file to upload (only its name is used with --dry-run).",
    )
    parser.add_argument(
        "--target-language",
        default="English",
        help="Target language for the workflow (default: English).",
    )
    parser.add_argument(
        "--tts-trim-start",
        type=float,
        default=0.0,
        help="TTS trim start, clamped to [0, 1] (default: 0.0).",
    )
    parser.add_argument(
        "--api-baseurl",
        default=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
        help="LLM API base URL (env: DEEPSEEK_BASE_URL).",
    )
    parser.add_argument(
        "--llm-model",
        default=os.environ.get("DEEPSEEK_MODEL", "deepseek-v4-flash"),
        help="LLM model name (env: DEEPSEEK_MODEL).",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the patched workflow JSON instead of uploading and submitting.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Command-line entry point.

    With ``--dry-run`` nothing is sent to the server: the patched workflow is
    written to stdout as JSON. It uses the bare ``--audio`` file name, or
    ``placeholder.mp3`` if no audio is given. Otherwise:

    - the audio (if given) is uploaded;
    - the prompt is queued and its id printed;
    - the history entry is written as JSON once it appears.
    """
    args = parse_args()
    workflow = load_workflow(args.workflow)
    audio_filename = args.audio.name if args.audio else "placeholder.mp3"
    if args.audio and not args.dry_run:
        # Fail with a readable message rather than a traceback from the upload.
        if not args.audio.is_file():
            raise SystemExit(f"Audio file not found: {args.audio}")
        audio_filename = upload_audio(args.server, args.audio)

    patched = patch_voicegate_workflow(
        workflow,
        audio_filename=audio_filename,
        target_language=args.target_language,
        api_key=os.environ.get("DEEPSEEK_API_KEY"),
        api_baseurl=args.api_baseurl,
        llm_model=args.llm_model,
        tts_trim_start=args.tts_trim_start,
    )

    if args.dry_run:
        json.dump(patched, STDOUT, ensure_ascii=False, indent=2)
        STDOUT.write("\n")
        STDOUT.flush()
        return

    prompt_id = submit_prompt(args.server, patched)
    # print() goes through sys.stdout, while the JSON below goes through the
    # separately buffered STDOUT on the same fd. Flush here so that, when piped,
    # this line cannot land after (or inside) the JSON output.
    print(f"Submitted prompt {prompt_id}", flush=True)
    history = wait_for_history(args.server, prompt_id)
    json.dump(history, STDOUT, ensure_ascii=False, indent=2)
    STDOUT.write("\n")
    # Flush explicitly instead of relying on interpreter-shutdown finalization.
    STDOUT.flush()


if __name__ == "__main__":
    main()