import copy
import json
import os
import time
import uuid
from typing import Any, Dict, Optional
from urllib.parse import urlencode

import gradio as gr
import requests
import websocket
from pydantic import BaseModel, Field
|
|
# Address of the ComfyUI server. It is used as ``http://{COMFY_HOST}`` and
# ``ws://{COMFY_HOST}``, so it must not include a scheme. Override with the
# COMFY_HOST environment variable.
COMFY_HOST = os.getenv("COMFY_HOST", "134.199.132.159")

# Path of the ComfyUI workflow used as the template for every request.
# Relative paths resolve against the current working directory; override with
# the COMFY_WORKFLOW environment variable (defaults to the original file name).
WORKFLOW_PATH = os.getenv("COMFY_WORKFLOW", "workflow.json")

# Loaded once at import time; per-request values are applied to a copy.
with open(WORKFLOW_PATH, "r", encoding="utf-8") as f:
    WORKFLOW_TEMPLATE: Dict[str, Any] = json.load(f)
|
|
class T2VReq(BaseModel):
    """Parameters for a single text-to-video generation request.

    ``token`` and ``text`` are required. Every other field is optional and
    defaults to the value shown (``None`` where no default applies).
    """

    # Required: value sent as the ``token`` query parameter to ComfyUI, and
    # the prompt text.
    token: str
    text: str

    # Optional prompt / sampling settings.
    negative: Optional[str] = None
    seed: Optional[int] = None
    steps: Optional[int] = 4
    cfg: Optional[float] = 1

    # Optional output geometry and timing.
    width: Optional[int] = 640
    height: Optional[int] = 640
    length: Optional[int] = 81  # frame count ("Frames" in the UI)
    fps: Optional[int] = 16

    # Optional prefix for the saved output file.
    filename_prefix: Optional[str] = "video/ComfyUI"
|
|
def _inject_params(prompt: Dict[str, Any], r: "T2VReq") -> Dict[str, Any]:
    """Return a copy of the workflow ``prompt`` with request values applied.

    ``prompt`` itself is never mutated; a deep copy is modified and returned.

    Assumes the template uses these node ids, as the bundled ``workflow.json``
    does (verify if the workflow changes):
        * ``"89"``: ``text``, the prompt text (always set);
        * ``"74"``: ``width`` / ``height`` / ``length``;
        * ``"88"``: ``fps``;
        * ``"80"``: ``filename_prefix``.
    Optional request fields that are ``None`` (or an empty prefix) leave the
    template's value untouched.

    When ``r.seed`` is given, it is written to every *literal* ``seed`` or
    ``noise_seed`` input in the graph. These are presumably the sampler seed
    inputs; confirm against ``workflow.json``. Inputs holding a list are
    links to another node's output and are left alone.

    NOTE(review): ``r.negative``, ``r.steps`` and ``r.cfg`` are still not
    applied, because their target nodes are not identifiable from here.
    TODO: wire them up once the workflow's node ids are confirmed.
    """
    p = copy.deepcopy(prompt)
    p["89"]["inputs"]["text"] = r.text

    if r.seed is not None:
        for node in p.values():
            inputs = node.get("inputs") or {}
            for key in ("seed", "noise_seed"):
                value = inputs.get(key)
                # Only overwrite literal ints (bool excluded); lists are node links.
                if isinstance(value, int) and not isinstance(value, bool):
                    inputs[key] = r.seed

    if r.width is not None:
        p["74"]["inputs"]["width"] = r.width
    if r.height is not None:
        p["74"]["inputs"]["height"] = r.height
    if r.length is not None:
        p["74"]["inputs"]["length"] = r.length
    if r.fps is not None:
        p["88"]["inputs"]["fps"] = r.fps
    if r.filename_prefix:
        p["80"]["inputs"]["filename_prefix"] = r.filename_prefix
    return p
|
|
def _open_ws(client_id: str, token: str):
    """Open a WebSocket to ComfyUI's ``/ws`` endpoint and return it connected.

    ``client_id`` is sent as the ``clientId`` query parameter and ``token`` as
    ``token``. Both are URL-encoded, so a user-supplied token containing
    reserved characters (``&``, ``=``, spaces, ...) cannot corrupt the query
    string. The caller owns the returned socket and must close it.
    """
    query = urlencode({"clientId": client_id, "token": token})
    ws = websocket.WebSocket()
    ws.connect(f"ws://{COMFY_HOST}/ws?{query}", timeout=1800)
    return ws
|
|
def _queue_prompt(prompt: Dict[str, Any], client_id: str, token: str) -> str:
    """POST ``prompt`` to ComfyUI's ``/prompt`` endpoint and return its prompt id.

    The token is passed through ``params`` so that requests URL-encodes it,
    rather than being spliced raw into the URL.

    Raises:
        RuntimeError: if the response status is not 200, or the JSON body has
            no ``prompt_id``.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    resp = requests.post(
        f"http://{COMFY_HOST}/prompt",
        params={"token": token},
        json=payload,
        timeout=1800,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"ComfyUI /prompt err: {resp.text}")
    data = resp.json()
    if "prompt_id" not in data:
        raise RuntimeError(f"/prompt no prompt_id: {data}")
    return data["prompt_id"]
|
|
def _get_history(prompt_id: str, token: str) -> Dict[str, Any]:
    """Fetch ComfyUI's ``/history/<prompt_id>`` and return that prompt's entry.

    Returns the value stored under ``prompt_id`` in the JSON response, or
    ``{}`` when the key is absent. The token is passed through ``params`` so
    that requests URL-encodes it.

    Raises:
        requests.HTTPError: on a 4xx/5xx response (via ``raise_for_status``).
    """
    resp = requests.get(
        f"http://{COMFY_HOST}/history/{prompt_id}",
        params={"token": token},
        timeout=1800,
    )
    resp.raise_for_status()
    hist = resp.json()
    return hist.get(prompt_id, {})
|
|
def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]:
    """Return the first video file reference found in a history entry's outputs.

    Each node output is scanned in order:
        1. ``images`` entries whose filename has a video extension;
        2. any complete entry under ``videos``, ``files`` or ``gifs``. The
           ``gifs`` key is presumably VideoHelperSuite-style output; confirm
           it if relied on.

    An entry is "complete" when it has ``filename``, ``subfolder`` and
    ``type``. Missing or ``None`` lists are treated as empty.

    Returns:
        ``{"filename", "subfolder", "type"}`` in that key order, which is
        the order ``urlencode`` emits them in the ``/view`` URL.

    Raises:
        RuntimeError: if no output contains a usable video reference.
    """
    ref_keys = ("filename", "subfolder", "type")
    video_exts = (".mp4", ".webm", ".gif", ".mov", ".mkv")

    def _is_ref(item: Dict[str, Any]) -> bool:
        return all(k in item for k in ref_keys)

    def _ref(item: Dict[str, Any]) -> Dict[str, str]:
        return {k: item[k] for k in ref_keys}

    for node_out in history.get("outputs", {}).values():
        # Video-saving nodes may report their file under "images"; filter by extension.
        for item in node_out.get("images") or []:
            if _is_ref(item) and item["filename"].lower().endswith(video_exts):
                return _ref(item)
        for key in ("videos", "files", "gifs"):
            for item in node_out.get(key) or []:
                if _is_ref(item):
                    return _ref(item)
    raise RuntimeError("No video file found in history outputs")
|
|
# Example prompts listed under the prompt box (fed to ``gr.Examples``).
sample_prompts = [
    "A golden retriever running across a beach at sunset, cinematic",
    "A cyberpunk city street at night with neon lights, light rain, slow pan",
    (
        "An astronaut walking on an alien planet covered in glowing crystals, "
        "purple sky with two moons, dust particles floating, slow panning shot, "
        "highly detailed, cinematic atmosphere."
    ),
    (
        "A cat gracefully jumping between rooftops in slow motion, warm sunset "
        "lighting, camera tracking the cat midair, cinematic composition, "
        "natural movement."
    ),
]
|
|
def _opt(cast, value):
    """Return ``cast(value)``, or ``None`` when ``value`` is ``None``."""
    return None if value is None else cast(value)


def _init_token() -> str:
    """Return a random UUID4 string used to pre-fill the token box on page load."""
    return str(uuid.uuid4())


def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
    """Run one text-to-video job on ComfyUI and stream ``(progress, video)`` updates.

    Yields ``(percent, None)`` while the job runs. It finishes with
    ``(100, video_url)``, where ``video_url`` is the ComfyUI ``/view`` URL of
    the produced file.

    Raises:
        gr.Error: if the prompt is empty, or ComfyUI reports an
            ``execution_error`` / ``execution_interrupted`` for this prompt.
        RuntimeError: propagated from the ComfyUI helper calls.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter a prompt.")

    # gr.Number yields floats (or None); coerce to the request model's types.
    req = T2VReq(
        token=token,
        text=text,
        seed=_opt(int, seed),
        steps=_opt(int, steps),
        cfg=_opt(float, cfg),
        width=_opt(int, width),
        height=_opt(int, height),
        length=_opt(int, length),
        fps=_opt(int, fps),
        filename_prefix=filename_prefix if filename_prefix else None,
    )
    prompt = _inject_params(WORKFLOW_TEMPLATE, req)
    client_id = str(uuid.uuid4())

    # Connect before queueing so no status message for this prompt is missed.
    ws = _open_ws(client_id, req.token)
    try:
        prompt_id = _queue_prompt(prompt, client_id, req.token)
        total_nodes = max(1, len(prompt))
        seen = set()
        p = 0
        last_emit = -1
        start = time.time()
        ws.settimeout(180)
        while True:
            out = ws.recv()
            if isinstance(out, (bytes, bytearray)):
                # Binary frames carry no status; after 2 s each one nudges the
                # bar forward by 1%, capped at 95%.
                if p < 95 and time.time() - start > 2:
                    p = min(95, p + 1)
                if p != last_emit:
                    last_emit = p
                    yield p, None
                continue

            msg = json.loads(out)
            mtype = msg.get("type")
            data = msg.get("data") or {}
            if data.get("prompt_id") != prompt_id:
                continue
            if mtype in ("execution_error", "execution_interrupted"):
                # Fail fast instead of waiting for the recv timeout.
                detail = data.get("exception_message") or mtype
                raise gr.Error(f"ComfyUI {mtype}: {detail}")
            if mtype != "executing":
                continue

            node = data.get("node")
            if node is None:
                # "executing" with no node marks the end of this prompt.
                break
            if node not in seen:
                seen.add(node)
                # Never move the bar backwards past what binary frames reached.
                p = max(p, min(99, int(len(seen) / total_nodes * 100)))
            if p != last_emit:
                last_emit = p
                yield p, None
    finally:
        # Runs on success, on errors, and when Gradio closes the generator.
        ws.close()

    hist = _get_history(prompt_id, req.token)
    info = _extract_video_from_history(hist)
    # URL-encode the token together with the file reference.
    query = urlencode({**info, "token": req.token})
    video_url = f"http://{COMFY_HOST}/view?{query}"
    yield 100, video_url


with gr.Blocks(
    title="T2V UI",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"),
) as demo:
    gr.Markdown("### Prompt")
    text = gr.Textbox(label="Prompt", placeholder="Describe the video you want", lines=3)
    gr.Examples(examples=sample_prompts, inputs=text)

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(label="Width", value=640, precision=0)
            height = gr.Number(label="Height", value=640, precision=0)
        with gr.Row():
            length = gr.Number(label="Frames", value=81, precision=0)
            fps = gr.Number(label="FPS", value=8, precision=0)
        with gr.Row():
            steps = gr.Number(label="Steps", value=4, precision=0)
            cfg = gr.Number(label="CFG", value=5.0)
            seed = gr.Number(label="Seed (optional)", value=None)
        filename_prefix = gr.Textbox(label="Filename prefix", value="video/ComfyUI")
        st_token = gr.Textbox(label="token", placeholder="name")

    with gr.Row():
        with gr.Column(scale=1):
            run_btn = gr.Button("Generate", variant="primary", scale=1)
            prog_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, step=1, interactive=False)
        with gr.Column(scale=1):
            out_video = gr.Video(label="Result", height=480)

    # Fill the token box with a fresh random value each time the page loads.
    demo.load(_init_token, outputs=st_token)

    run_btn.click(
        generate_fn,
        inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
        outputs=[prog_bar, out_video],
    )
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server.
    app = demo.queue()
    app.launch()