Spaces:
Running
Running
| from pydantic import BaseModel, Field | |
| from typing import Optional, Dict, Any | |
| import json, uuid, time, os | |
| import requests | |
| import websocket | |
| from urllib.parse import urlencode | |
| import gradio as gr | |
# Address of the ComfyUI server (host, optionally host:port), overridable via the
# COMFY_HOST environment variable. NOTE(review): the default has no port, so plain
# http/ws default ports are assumed — confirm this matches the deployment.
COMFY_HOST = os.environ.get("COMFY_HOST", "134.199.132.159")

# Workflow template read once at import time from ./workflow.json (relative to the
# current working directory). Requests never mutate it; _inject_params works on a copy.
with open("workflow.json", "r", encoding="utf-8") as _workflow_file:
    WORKFLOW_TEMPLATE: Dict[str, Any] = json.load(_workflow_file)
class T2VReq(BaseModel):
    # Parameters of one text-to-video request. `token` and `text` are required; for the
    # other fields None (or an empty filename_prefix) keeps the workflow template's
    # value, except seed, where None means "draw a random seed".

    # Session token forwarded to ComfyUI as the `token` query parameter.
    token: str = Field(default=...)
    # Prompt text describing the video.
    text: str = Field(default=...)
    # Negative prompt. NOTE(review): not read by _inject_params in this file.
    negative: Optional[str] = None

    # Sampling settings.
    seed: Optional[int] = None
    steps: Optional[int] = 20
    cfg: Optional[float] = 5.0

    # Output geometry and timing: frame width/height, frame count, frames per second.
    width: Optional[int] = 1280
    height: Optional[int] = 704
    length: Optional[int] = 121
    fps: Optional[int] = 24

    # Prefix handed to the workflow's save node for the output file name.
    filename_prefix: Optional[str] = "video/ComfyUI"
def _inject_params(prompt: Dict[str, Any], r: T2VReq) -> Dict[str, Any]:
    """Return a copy of *prompt* with the request's parameters written into it.

    Node ids are hard-wired to the bundled workflow: "6" gets the prompt text,
    "3" seed/steps/cfg, "55" width/height/length, "57" fps and "58"
    filename_prefix. Optional request fields that are None (or an empty
    filename_prefix) leave the template's value untouched.

    Side effect: when ``r.seed`` is None a random non-negative 63-bit seed is
    generated and stored back on ``r`` so the caller can see which seed ran.
    """
    # JSON round-trip = deep copy; the template dict is shared by all requests.
    patched = json.loads(json.dumps(prompt))
    patched["6"]["inputs"]["text"] = r.text

    if r.seed is None:
        r.seed = int.from_bytes(os.urandom(8), "big") & ((1 << 63) - 1)
    patched["3"]["inputs"]["seed"] = r.seed

    optional_overrides = (
        ("3", "steps", r.steps),
        ("3", "cfg", r.cfg),
        ("55", "width", r.width),
        ("55", "height", r.height),
        ("55", "length", r.length),
        ("57", "fps", r.fps),
    )
    for node_id, input_name, value in optional_overrides:
        if value is not None:
            patched[node_id]["inputs"][input_name] = value

    # Truthiness check on purpose: an empty prefix also keeps the template's value.
    if r.filename_prefix:
        patched["58"]["inputs"]["filename_prefix"] = r.filename_prefix
    return patched
def _open_ws(client_id: str, token: str):
    """Open a WebSocket to the ComfyUI server's ``/ws`` endpoint.

    Args:
        client_id: id identifying this client; the same id is sent with the
            queued prompt so the run's messages arrive on this socket.
        token: session token sent as the ``token`` query parameter.

    Returns:
        The connected ``websocket.WebSocket``. The caller owns it and must close it.
    """
    # urlencode escapes reserved characters, so a value containing '&', '=' or
    # spaces cannot corrupt or extend the query string.
    query = urlencode({"clientId": client_id, "token": token})
    ws = websocket.WebSocket()
    # 1800 s connect timeout; generate_fn lowers the per-recv timeout afterwards.
    ws.connect(f"ws://{COMFY_HOST}/ws?{query}", timeout=1800)
    return ws
def _queue_prompt(prompt: Dict[str, Any], client_id: str, token: str) -> str:
    """Submit a filled-in workflow to ComfyUI (``POST /prompt``) and return its id.

    Args:
        prompt: the workflow with request parameters injected.
        client_id: the id the caller's WebSocket was opened with, so this run's
            messages are routed to that socket.
        token: session token sent as the ``token`` query parameter.

    Returns:
        The ``prompt_id`` assigned by the server.

    Raises:
        RuntimeError: on a non-200 response, or when the reply has no ``prompt_id``.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    # Encode the token instead of interpolating it raw into the URL.
    url = f"http://{COMFY_HOST}/prompt?{urlencode({'token': token})}"
    resp = requests.post(url, json=payload, timeout=1800)
    if resp.status_code != 200:
        raise RuntimeError(f"ComfyUI /prompt err: {resp.text}")
    data = resp.json()
    if "prompt_id" not in data:
        raise RuntimeError(f"/prompt no prompt_id: {data}")
    return data["prompt_id"]
def _get_history(prompt_id: str, token: str) -> Dict[str, Any]:
    """Fetch the ComfyUI history record of one run (``GET /history/<prompt_id>``).

    Args:
        prompt_id: id returned by ``_queue_prompt``.
        token: session token sent as the ``token`` query parameter.

    Returns:
        The record stored under *prompt_id*, or ``{}`` if the server has none.

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    # Encode the token instead of interpolating it raw into the URL.
    query = urlencode({"token": token})
    resp = requests.get(f"http://{COMFY_HOST}/history/{prompt_id}?{query}", timeout=1800)
    resp.raise_for_status()
    return resp.json().get(prompt_id, {})
| def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]: | |
| outputs = history.get("outputs", {}) | |
| for _, node_out in outputs.items(): | |
| if "images" in node_out: | |
| for it in node_out["images"]: | |
| if all(k in it for k in ("filename", "subfolder", "type")): | |
| fn = it["filename"] | |
| if fn.lower().endswith((".mp4", ".webm", ".gif", ".mov", ".mkv")): | |
| return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]} | |
| for key in ("videos", "files"): | |
| if key in node_out and node_out[key]: | |
| it = node_out[key][0] | |
| if all(k in it for k in ("filename", "subfolder", "type")): | |
| return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]} | |
| raise RuntimeError("No video file found in history outputs") | |
# Example prompts offered under the prompt box via gr.Examples. Long prompts are
# split into adjacent string literals, which Python joins at compile time. Each
# fragment keeps its trailing space, so every value is exactly the single-line text.
# NOTE: the missing space in "establishing shot.In" is part of the original text.
sample_prompts = [
    "A golden retriever running across a beach at sunset, cinematic, 24fps",
    (
        "Aerial shot, warm colors, extreme wide shot, sunny lighting, hard lighting, "
        "daylight, establishing shot.In a desolate desert, a black SUV speeds along a highway. "
        "In a high-angle shot, the vehicle is seen driving on the left side of the road, "
        "with a roof rack and a red taillight on top. The camera slowly pushes in. "
        "In front of the vehicle are vast yellow sand dunes, and a few mountain peaks can be seen in the distance. "
        "The sky is a pale blue, and sunlight filters through the clouds, "
        "bringing a touch of warmth to the desolate land. "
        "The sides of the road are lined with dry grass dotted with some low shrubs."
    ),
    (
        "Pixel art style. In a colorful universe, a player-controlled pixel character travels "
        "between planets of various shapes and unique color tones, each with strange terrain and alien creatures. "
        "A close-up shot shows the player character in the center of the frame, "
        "in dialogue with a friendly alien creature. "
        "The alien has a rounded body and large eyes, appearing very cute. "
        "Above, pixelated cosmic storms and energy vortex effects rotate slowly, adding a sense of dynamism. "
        "The overall style is retro yet futuristic, with a vibrant and lively color palette."
    ),
    (
        "In an oil painting style, a vast sea of sunflowers unfolds, their golden heads "
        "blooming brilliantly in the faint light of dawn or dusk. "
        "The impasto technique is prominently used, with bold, thick brushstrokes "
        "lending a powerful texture and vitality to the sky and petals. "
        "The deep blue of the sky creates a striking contrast with the bright sunflowers, "
        "cultivating an atmosphere that is both serene and hopeful. "
        "Swarms of bees dance and hover ceaselessly among the flowers, busily collecting nectar. "
        "These dynamic bees inject boundless energy and a sense of motion into the otherwise static painting, "
        "filling the entire scene with the vibrant life of summer and the joy of harvest."
    ),
]
with gr.Blocks(
    title="T2V UI",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"),
) as demo:
    # Per-session token: filled by _init_token on page load and sent with every
    # ComfyUI request made for this session.
    st_token = gr.State()

    gr.Markdown("# Wan2.2 T2V running on AMD MI300x")
    gr.Markdown("### Prompt")
    text = gr.Textbox(label="Prompt", placeholder="Describe the video you want", lines=3)
    gr.Examples(examples=sample_prompts, inputs=text)
    gr.Markdown(
        "[More Prompts](https://alidocs.dingtalk.com/i/nodes/EpGBa2Lm8aZxe5myC99MelA2WgN7R35y)"
    )

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(label="Width", value=1280, precision=0)
            height = gr.Number(label="Height", value=704, precision=0)
        with gr.Row():
            length = gr.Number(label="Frames", value=121, precision=0)
            fps = gr.Number(label="FPS", value=24, precision=0)
        with gr.Row():
            steps = gr.Number(label="Steps", value=20, precision=0)
            cfg = gr.Number(label="CFG", value=5.0)
            seed = gr.Number(label="Seed (optional)", value=None)
        filename_prefix = gr.Textbox(label="Filename prefix", value="video/ComfyUI")

    run_btn = gr.Button("Generate", variant="primary")
    prog_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, step=1, interactive=False)
    out_video = gr.Video(label="Result")

    def _init_token():
        """Return a fresh random session token (UUID4 string)."""
        return str(uuid.uuid4())

    demo.load(_init_token, outputs=st_token)

    def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
        """Run one text-to-video job on ComfyUI, streaming progress to the UI.

        Yields ``(progress, video_url)`` pairs. Progress is an int 0-100 and the
        URL is None until the final yield, which is ``(100, <ComfyUI /view URL>)``.

        Progress comes from two sources. Each distinct node reported as
        "executing" advances it in proportion to the workflow's node count,
        capped at 99. Binary WebSocket frames (sent by ComfyUI during execution)
        nudge it by 1, up to 95, once 2 s have passed.

        Raises:
            gr.Error: if ComfyUI reports an execution error or an interruption
                for this prompt.
            Exceptions from the helpers and the WebSocket (connect/queue
            failures, a recv timeout after 60 s of silence, missing video)
            propagate to Gradio.
        """
        req = T2VReq(
            token=token,
            text=text,
            seed=int(seed) if seed is not None else None,
            steps=int(steps) if steps is not None else None,
            cfg=float(cfg) if cfg is not None else None,
            width=int(width) if width is not None else None,
            height=int(height) if height is not None else None,
            length=int(length) if length is not None else None,
            fps=int(fps) if fps is not None else None,
            filename_prefix=filename_prefix if filename_prefix else None,
        )
        prompt = _inject_params(WORKFLOW_TEMPLATE, req)
        client_id = str(uuid.uuid4())

        # Open the socket before queueing so no early messages for the run are missed.
        ws = _open_ws(client_id, req.token)
        try:
            prompt_id = _queue_prompt(prompt, client_id, req.token)
            total_nodes = max(1, len(prompt))
            seen = set()
            p = 0
            last_emit = -1
            start = time.time()
            ws.settimeout(60)
            while True:
                out = ws.recv()
                if isinstance(out, (bytes, bytearray)):
                    # Binary frames carry no node info; use them only as a heartbeat
                    # for a slowly creeping progress bar.
                    if p < 95 and time.time() - start > 2:
                        p = min(95, p + 1)
                        if p != last_emit:
                            last_emit = p
                            yield p, None
                    continue

                msg = json.loads(out)
                msg_type = msg.get("type")
                if msg_type not in ("executing", "execution_error", "execution_interrupted"):
                    continue
                data = msg.get("data", {})
                if data.get("prompt_id") != prompt_id:
                    continue  # messages for other runs on the same server

                if msg_type == "execution_error":
                    raise gr.Error(
                        f"ComfyUI execution error in node {data.get('node_id')} "
                        f"({data.get('node_type')}): {data.get('exception_message')}"
                    )
                if msg_type == "execution_interrupted":
                    raise gr.Error("ComfyUI execution was interrupted")

                node = data.get("node")
                if node is None:
                    break  # "executing" with node None ends the run
                if node not in seen:
                    seen.add(node)
                    p = min(99, int(len(seen) / total_nodes * 100))
                    if p != last_emit:
                        last_emit = p
                        yield p, None
        finally:
            # Runs on normal exit, on any error, and when Gradio closes the generator
            # (e.g. client disconnect), so the socket is never leaked.
            ws.close()

        hist = _get_history(prompt_id, req.token)
        info = _extract_video_from_history(hist)
        # Same parameter order as before (filename, subfolder, type, token), with the
        # token url-encoded as well.
        video_url = f"http://{COMFY_HOST}/view?{urlencode({**info, 'token': req.token})}"
        yield 100, video_url

    run_btn.click(
        generate_fn,
        inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
        outputs=[prog_bar, out_video],
    )
if __name__ == "__main__":
    # Script entry point: turn on Gradio's request queue for the app, then start
    # the web server. queue() configures `demo` in place and returns it, so these
    # two calls are equivalent to the chained form.
    demo.queue()
    demo.launch()