| from pydantic import BaseModel, Field |
| from typing import Optional, Dict, Any |
| import json, uuid, time, os |
| import requests |
| import websocket |
| from urllib.parse import urlencode |
| import gradio as gr |
|
|
# Address (host or host:port) of the ComfyUI server used in every ws:// and
# http:// URL below; override it with the COMFY_HOST environment variable.
COMFY_HOST = os.environ.get("COMFY_HOST", "134.199.132.159")

# Workflow template: a mapping of node id -> {"class_type": ..., "inputs": {...}}
# that is copied and patched per request. Read once at import time from
# "workflow.json" relative to the current working directory.
with open("workflow.json", encoding="utf-8") as _workflow_file:
    WORKFLOW_TEMPLATE: Dict[str, Any] = json.loads(_workflow_file.read())
|
|
class T2VReq(BaseModel):
    """A single text-to-video request: auth token, prompt text and render overrides."""

    # Mandatory: no default value, so validation fails if either is missing.
    token: str
    text: str

    # Optional overrides patched into the workflow; a value of None (or an empty
    # filename_prefix) leaves the workflow template's own value untouched.
    width: Optional[int] = Field(default=640)
    height: Optional[int] = Field(default=640)
    length: Optional[int] = Field(default=81)  # number of frames
    fps: Optional[int] = Field(default=16)
    filename_prefix: Optional[str] = Field(default="video/ComfyUI")
|
|
| def _inject_params(prompt: Dict[str, Any], r: T2VReq) -> Dict[str, Any]: |
| p = json.loads(json.dumps(prompt)) |
| for node_id, node_data in p.items(): |
| class_type = node_data.get("class_type", "") |
| |
| if class_type == "CLIPTextEncode" and "text" in node_data.get("inputs", {}): |
| if "色调艳丽" not in node_data["inputs"]["text"]: |
| node_data["inputs"]["text"] = r.text |
| elif class_type == "EmptyLatentImage": |
| if r.width is not None: node_data["inputs"]["width"] = r.width |
| if r.height is not None: node_data["inputs"]["height"] = r.height |
| if r.length is not None: node_data["inputs"]["batch_size"] = r.length |
| elif class_type == "CreateVideo": |
| if r.fps is not None: node_data["inputs"]["fps"] = r.fps |
| elif class_type == "SaveVideo": |
| if r.filename_prefix: node_data["inputs"]["filename_prefix"] = r.filename_prefix |
| if r.fps is not None: node_data["inputs"]["fps"] = r.fps |
| |
| return p |
|
|
| def _open_ws(client_id: str, token: str): |
| ws = websocket.WebSocket() |
| ws.connect(f"ws://{COMFY_HOST}/ws?clientId={client_id}&token={token}", timeout=1800) |
| return ws |
|
|
def _queue_prompt(prompt: Dict[str, Any], client_id: str, token: str) -> str:
    """Submit *prompt* to ComfyUI's ``/prompt`` endpoint.

    Args:
        prompt: Workflow mapping to execute.
        client_id: Id of the websocket client that should receive progress.
        token: Auth token forwarded as the ``token`` query parameter.

    Returns:
        The ``prompt_id`` from the JSON response.

    Raises:
        RuntimeError: The server answered with a non-200 status.
        KeyError: The response JSON has no ``prompt_id``.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    resp = requests.post(
        f"http://{COMFY_HOST}/prompt",
        # Let requests percent-encode the user-supplied token.
        params={"token": token},
        json=payload,
        timeout=1800,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"ComfyUI /prompt err: {resp.text}")
    return resp.json()["prompt_id"]
|
|
def _get_history(prompt_id: str, token: str) -> Dict[str, Any]:
    """Fetch the ComfyUI history entry for *prompt_id*.

    Returns:
        The entry stored under *prompt_id*, or ``{}`` when the response has none.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
    """
    resp = requests.get(
        f"http://{COMFY_HOST}/history/{prompt_id}",
        # Let requests percent-encode the user-supplied token.
        params={"token": token},
        timeout=1800,
    )
    resp.raise_for_status()
    return resp.json().get(prompt_id, {})
|
|
| def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]: |
| if "outputs" not in history: |
| raise RuntimeError(f"Server crashed before creating outputs! History dump: {json.dumps(history)}") |
| |
| outputs = history.get("outputs", {}) |
| for _, node_out in outputs.items(): |
| for key in ("images", "videos", "files", "gifs"): |
| if key in node_out: |
| for it in node_out[key]: |
| if isinstance(it, dict) and it.get("filename", "").lower().endswith((".mp4", ".webm", ".gif", ".mov")): |
| return {"filename": it["filename"], "subfolder": it.get("subfolder", ""), "type": it.get("type", "output")} |
| |
| raise RuntimeError(f"Video not found! Node outputs were: {json.dumps(outputs)}") |
|
|
# Example prompts offered under the prompt box via gr.Examples.
sample_prompts = [
    "A majestic cinematic shot of Lord Shiva meditating in the snowy Himalayas, "
    "cosmic universe and stars swirling around him, highly detailed, photorealistic",
    "An ancient, mysterious temple glowing with spiritual energy "
    "in the middle of a dark forest, 4k resolution, slow pan",
    "A cinematic view of the cosmos, colorful galaxies and stardust "
    "forming sacred mysteries, majestic slow motion",
    "A cyberpunk city street at night with neon lights, "
    "light rain, slow cinematic tracking shot",
]
|
|
with gr.Blocks(title="T2V UI", theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("# Experience Wan2.2 14B Text-to-Video")

    text = gr.Textbox(label="Prompt", placeholder="Describe the video you want...", lines=3)
    gr.Examples(examples=sample_prompts, inputs=text)

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(label="Width", value=640, precision=0)
            height = gr.Number(label="Height", value=640, precision=0)
        with gr.Row():
            length = gr.Number(label="Frames", value=81, precision=0)
            fps = gr.Number(label="FPS", value=16, precision=0)
        filename_prefix = gr.Textbox(label="Filename prefix", value="video/ComfyUI")
        # Token forwarded to ComfyUI with every request; pre-filled on page load.
        st_token = gr.Textbox(label="token", placeholder="name")

    with gr.Row():
        with gr.Column(scale=1):
            run_btn = gr.Button("Generate Video", variant="primary")
            prog_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, interactive=False)
        with gr.Column(scale=1):
            out_video = gr.Video(label="Result", height=480)

    def _init_token():
        """Return a fresh random UUID string used as the default token."""
        return str(uuid.uuid4())

    demo.load(_init_token, outputs=st_token)

    def _track_progress(ws, prompt_id, total_nodes):
        """Yield ``(percent, None)`` progress updates until *prompt_id* finishes.

        Percent is the share of the workflow's *total_nodes* reported as
        executing or served from cache, capped at 99 until the result is fetched.

        Raises:
            gr.Error: ComfyUI sent ``execution_error``, or interrupted this prompt.
        """
        seen = set()
        last_emit = -1
        while True:
            out = ws.recv()
            if isinstance(out, bytes):
                continue  # binary frames carry no JSON status
            msg = json.loads(out)
            msg_type = msg.get("type")
            data = msg.get("data") or {}

            if msg_type == "execution_error":
                node_type = data.get("node_type", "Unknown Node")
                exc_msg = data.get("exception_message", "Unknown Error")
                raise gr.Error(f"❌ Server Crashed at '{node_type}': {exc_msg}")

            if data.get("prompt_id") != prompt_id:
                continue  # status broadcasts or messages for another prompt

            if msg_type == "execution_interrupted":
                raise gr.Error("❌ Generation was interrupted on the server.")

            if msg_type == "execution_cached":
                # Cached nodes never get an "executing" message; count them here.
                seen.update(data.get("nodes") or [])
            elif msg_type == "executing":
                node = data.get("node")
                if node is None:
                    return  # node=None marks the end of this prompt's execution
                seen.add(node)
            else:
                continue

            p = min(99, int(len(seen) / total_nodes * 100))
            if p != last_emit:
                last_emit = p
                yield p, None

    def generate_fn(text, width, height, length, fps, filename_prefix, token):
        """Queue a text-to-video job on ComfyUI and stream its progress.

        Yields ``(progress_percent, video_url_or_None)`` pairs for the progress
        slider and video player; the final pair is ``(100, url)``. Blank or zero
        numeric fields fall back to 640/640/81/16.

        Raises:
            gr.Error: Empty prompt, or ComfyUI reported an error/interrupt.
            RuntimeError: The /prompt call failed or no video was produced.
        """
        if not text:
            raise gr.Error("Please enter a prompt first!")

        req = T2VReq(
            token=token,
            text=text,
            width=int(width) if width else 640,
            height=int(height) if height else 640,
            length=int(length) if length else 81,
            fps=int(fps) if fps else 16,
            filename_prefix=filename_prefix,
        )

        prompt = _inject_params(WORKFLOW_TEMPLATE, req)

        client_id = str(uuid.uuid4())
        # Connect before queueing so no progress message for this client is missed.
        ws = _open_ws(client_id, req.token)
        try:
            prompt_id = _queue_prompt(prompt, client_id, req.token)
            yield from _track_progress(ws, prompt_id, max(1, len(prompt)))
        finally:
            # Runs on success, on errors (including a failed /prompt) and when
            # the generator is closed early, so the socket is never leaked.
            ws.close()

        hist = _get_history(prompt_id, req.token)
        info = _extract_video_from_history(hist)
        # Encode the token together with the file info so it is escaped as well.
        query = urlencode({**info, "token": req.token})
        video_url = f"http://{COMFY_HOST}/view?{query}"
        yield 100, video_url

    run_btn.click(
        generate_fn,
        inputs=[text, width, height, length, fps, filename_prefix, st_token],
        outputs=[prog_bar, out_video],
    )
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve on all interfaces at port 7860.
    app = demo.queue()
    app.launch(server_name="0.0.0.0", server_port=7860)
|
|