from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
import json, uuid, time, os
import requests
import websocket
from urllib.parse import urlencode
import gradio as gr

# ComfyUI backend address (host[:port]); override via the COMFY_HOST env var.
COMFY_HOST = os.getenv("COMFY_HOST", "134.199.132.159")

# Workflow graph template, loaded once at startup. Never mutated directly:
# _inject_params deep-copies it for every request.
with open("workflow.json", "r", encoding="utf-8") as f:
    WORKFLOW_TEMPLATE: Dict[str, Any] = json.load(f)


class T2VReq(BaseModel):
    """Parameters for one text-to-video generation request."""

    token: str = Field(...)            # ComfyUI access token
    text: str = Field(...)             # positive prompt
    negative: Optional[str] = None     # negative prompt (currently unused)
    seed: Optional[int] = None         # accepted but ignored, see _inject_params
    steps: Optional[int] = 4           # accepted but ignored, see _inject_params
    cfg: Optional[float] = 1           # accepted but ignored, see _inject_params
    width: Optional[int] = 640
    height: Optional[int] = 640
    length: Optional[int] = 160        # number of frames
    fps: Optional[int] = 10
    filename_prefix: Optional[str] = "video/ComfyUI"


def _inject_params(prompt: Dict[str, Any], r: T2VReq) -> Dict[str, Any]:
    """Return a deep copy of the workflow graph with request params injected.

    The node ids ("89", "74", "88", "80") are specific to workflow.json.
    """
    # JSON round-trip == deep copy; the template is pure JSON data.
    p = json.loads(json.dumps(prompt))
    p["89"]["inputs"]["text"] = r.text
    # NOTE(review): seed / steps / cfg injection is deliberately disabled
    # below; the request fields are accepted but ignored. Re-enable only
    # after confirming nodes "3" and "78" exist in workflow.json.
    # if r.seed is None:
    #     r.seed = int.from_bytes(os.urandom(8), "big") & ((1 << 63) - 1)
    # p["3"]["inputs"]["seed"] = r.seed
    # if r.steps is not None: p["78"]["inputs"]["steps"] = r.steps
    # if r.cfg is not None: p["78"]["inputs"]["cfg"] = r.cfg
    if r.width is not None:
        p["74"]["inputs"]["width"] = r.width
    if r.height is not None:
        p["74"]["inputs"]["height"] = r.height
    if r.length is not None:
        p["74"]["inputs"]["length"] = r.length
    if r.fps is not None:
        p["88"]["inputs"]["fps"] = r.fps
    if r.filename_prefix:
        p["80"]["inputs"]["filename_prefix"] = r.filename_prefix
    return p


def _open_ws(client_id: str, token: str):
    """Open a websocket to ComfyUI to receive progress events for client_id."""
    ws = websocket.WebSocket()
    # urlencode the query so a token containing reserved characters
    # (&, =, #, ...) cannot corrupt the URL.
    q = urlencode({"clientId": client_id, "token": token})
    ws.connect(f"ws://{COMFY_HOST}/ws?{q}", timeout=1800)
    return ws


def _queue_prompt(prompt: Dict[str, Any], client_id: str, token: str) -> str:
    """Submit a workflow graph to ComfyUI /prompt and return its prompt_id.

    Raises:
        RuntimeError: if the server rejects the prompt or omits prompt_id.
    """
    payload = {"prompt": prompt, "client_id": client_id}
    resp = requests.post(
        f"http://{COMFY_HOST}/prompt",
        params={"token": token},  # requests escapes the token for us
        json=payload,
        timeout=1800,
    )
    if resp.status_code != 200:
        raise RuntimeError(f"ComfyUI /prompt err: {resp.text}")
    data = resp.json()
    if "prompt_id" not in data:
        raise RuntimeError(f"/prompt no prompt_id: {data}")
    return data["prompt_id"]


def _get_history(prompt_id: str, token: str) -> Dict[str, Any]:
    """Fetch the execution-history entry for prompt_id ({} when absent)."""
    r = requests.get(
        f"http://{COMFY_HOST}/history/{prompt_id}",
        params={"token": token},
        timeout=1800,
    )
    r.raise_for_status()
    hist = r.json()
    return hist.get(prompt_id, {})


def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]:
    """Find the first video output in a history entry.

    Searches each node's "images" list (filtered to video extensions),
    then its "videos" / "files" lists. Returns a dict with the
    filename/subfolder/type keys expected by the /view endpoint.

    Raises:
        RuntimeError: if no video output is present.
    """
    video_exts = (".mp4", ".webm", ".gif", ".mov", ".mkv")
    required = ("filename", "subfolder", "type")
    outputs = history.get("outputs", {})
    for node_out in outputs.values():
        if "images" in node_out:
            for it in node_out["images"]:
                if all(k in it for k in required):
                    if it["filename"].lower().endswith(video_exts):
                        return {
                            "filename": it["filename"],
                            "subfolder": it["subfolder"],
                            "type": it["type"],
                        }
        for key in ("videos", "files"):
            if key in node_out and node_out[key]:
                it = node_out[key][0]
                if all(k in it for k in required):
                    return {
                        "filename": it["filename"],
                        "subfolder": it["subfolder"],
                        "type": it["type"],
                    }
    raise RuntimeError("No video file found in history outputs")

# sample_prompts = [
#     "A golden retriever running across a beach at sunset, cinematic",
#     "A cyberpunk city street at night with neon lights, light rain, slow pan",
#     "An astronaut walking on an alien planet covered in glowing crystals, purple sky with two moons, dust particles floating, slow panning shot, highly detailed, cinematic atmosphere.",
#     "A cat gracefully jumping between rooftops in slow motion, warm sunset lighting, camera tracking the cat midair, cinematic composition, natural movement."
#]

with gr.Blocks(
    title="T2V UI",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"),
) as demo:
    # st_token = gr.State()
    gr.Markdown("# Experience Wan2.2 14B Text-to-Video on AMD MI300X — Free Trial")
    gr.Markdown("Powered by [AMD Devcloud](https://oneclickamd.ai/) and [ComfyUI](https://github.com/comfyanonymous/ComfyUI)")
    gr.Markdown("### Prompt")
    text = gr.Textbox(label="Prompt", placeholder="Describe the video you want", lines=3)
    # gr.Examples(examples=sample_prompts, inputs=text)
    with gr.Row():
        width = gr.Number(label="Width", value=640, precision=0, step=8)
        height = gr.Number(label="Height", value=640, precision=0, step=8)
    # with gr.Row():
    length = gr.Number(label="Frames", value=160, precision=0)
    fps = gr.Number(label="FPS", value=10, precision=0)
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            steps = gr.Number(label="Steps", value=4, precision=0)
            cfg = gr.Number(label="CFG", value=5.0)
            seed = gr.Number(label="Seed (optional)", value=None)
        with gr.Row():
            filename_prefix = gr.Textbox(label="Filename prefix", value="video/ComfyUI")
            st_token = gr.Textbox(label="token", placeholder="name")
    with gr.Row():
        with gr.Column(scale=1):
            run_btn = gr.Button("Generate", variant="primary", scale=1)
            prog_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, step=1, interactive=False)
        with gr.Column(scale=1):
            out_video = gr.Video(label="Result", height=480)

    def _init_token():
        """Seed the token box with a fresh UUID on page load."""
        return str(uuid.uuid4())

    demo.load(_init_token, outputs=st_token)

    def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
        """Queue a generation on ComfyUI and stream progress back to the UI.

        Generator yielding (progress 0-100, video_url or None); the final
        yield carries the /view URL of the rendered video.
        """
        req = T2VReq(
            token=token,
            text=text,
            seed=int(seed) if seed is not None else None,
            steps=int(steps) if steps is not None else None,
            cfg=float(cfg) if cfg is not None else None,
            width=int(width) if width is not None else None,
            height=int(height) if height is not None else None,
            length=int(length) if length is not None else None,
            fps=int(fps) if fps is not None else None,
            filename_prefix=filename_prefix if filename_prefix else None,
        )
        prompt = _inject_params(WORKFLOW_TEMPLATE, req)
        client_id = str(uuid.uuid4())
        ws = _open_ws(client_id, req.token)
        try:
            prompt_id = _queue_prompt(prompt, client_id, req.token)
            total_nodes = max(1, len(prompt))
            seen = set()
            p = 0
            last_emit = -1
            start = time.time()
            # Per-recv timeout; the connection itself was opened with a
            # much longer (1800 s) timeout.
            ws.settimeout(180)
            while True:
                out = ws.recv()
                if isinstance(out, (bytes, bytearray)):
                    # Binary frames (previews) carry no node info; creep the
                    # bar toward 95% so the user still sees movement.
                    if p < 95 and time.time() - start > 2:
                        p = min(95, p + 1)
                        if p != last_emit:
                            last_emit = p
                            yield p, None
                    continue
                msg = json.loads(out)
                if msg.get("type") == "executing":
                    data = msg.get("data", {})
                    if data.get("prompt_id") != prompt_id:
                        continue
                    node = data.get("node")
                    if node is None:
                        # node == None signals the end of execution.
                        break
                    if node not in seen:
                        seen.add(node)
                        p = min(99, int(len(seen) / total_nodes * 100))
                        if p != last_emit:
                            last_emit = p
                            yield p, None
        finally:
            # Always release the socket — the original closed it only on the
            # success path and leaked it when queueing or recv() raised.
            ws.close()
        hist = _get_history(prompt_id, req.token)
        info = _extract_video_from_history(hist)
        q = urlencode(info)
        video_url = f"http://{COMFY_HOST}/view?{q}&token={req.token}"
        yield 100, video_url

    run_btn.click(
        generate_fn,
        inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
        outputs=[prog_bar, out_video],
    )

if __name__ == "__main__":
    demo.queue().launch()