"""Lightweight in-process API wrapper around WanGP generation.""" from __future__ import annotations import contextlib import copy import importlib import inspect import io import json import os import queue import sys import threading import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Iterator, Sequence from PIL import Image from shared.utils.process_locks import set_main_generation_running from shared.utils.thread_utils import AsyncStream _RUNTIME_LOCK = threading.RLock() _GENERATION_LOCK = threading.RLock() _RUNTIME: "_WanGPRuntime | None" = None _BANNER_PRINTED = False @dataclass(frozen=True) class StreamMessage: stream: str text: str @dataclass(frozen=True) class ProgressUpdate: phase: str status: str progress: int current_step: int | None total_steps: int | None raw_phase: str | None = None unit: str | None = None @dataclass(frozen=True) class PreviewUpdate: image: Image.Image | None phase: str status: str progress: int current_step: int | None total_steps: int | None @dataclass(frozen=True) class SessionEvent: kind: str data: Any = None timestamp: float = field(default_factory=time.time) @dataclass(frozen=True) class GenerationResult: success: bool generated_files: list[str] errors: list["GenerationError"] total_tasks: int successful_tasks: int failed_tasks: int @dataclass(frozen=True) class GenerationError: message: str task_index: int | None = None task_id: Any = None stage: str | None = None def __str__(self) -> str: return self.message class SessionStream: def __init__(self) -> None: self._queue: queue.Queue[SessionEvent | object] = queue.Queue() self._closed = threading.Event() self._sentinel = object() def put(self, kind: str, data: Any = None) -> None: if self._closed.is_set(): return self._queue.put(SessionEvent(kind=kind, data=data)) def close(self) -> None: if self._closed.is_set(): return self._closed.set() self._queue.put(self._sentinel) def get(self, timeout: float | None = None) -> SessionEvent 
| None: try: item = self._queue.get(timeout=timeout) except queue.Empty: return None if item is self._sentinel: return None return item def iter(self, timeout: float | None = None) -> Iterator[SessionEvent]: while True: event = self.get(timeout=timeout) if event is None: if self._closed.is_set(): break continue yield event @property def closed(self) -> bool: return self._closed.is_set() class _OutputCapture(io.TextIOBase): def __init__( self, stream_name: str, emit_line, console: io.TextIOBase | None = None, *, console_isatty: bool = True, ) -> None: self._stream_name = stream_name self._emit_line = emit_line self._console = console self._console_isatty = bool(console_isatty) self._buffer = "" def writable(self) -> bool: return True @property def encoding(self) -> str: return str(getattr(self._console, "encoding", "utf-8")) def isatty(self) -> bool: return self._console_isatty def write(self, text: str) -> int: if not text: return 0 if self._console is not None: self._console.write(text) self._buffer += text self._drain(False) return len(text) def flush(self) -> None: if self._console is not None: self._console.flush() self._drain(True) def _drain(self, flush_all: bool) -> None: while True: split_at = -1 for delimiter in ("\r", "\n"): index = self._buffer.find(delimiter) if index >= 0 and (split_at < 0 or index < split_at): split_at = index if split_at < 0: break line = self._buffer[:split_at] self._buffer = self._buffer[split_at + 1 :] if line: self._emit_line(self._stream_name, line) if flush_all and self._buffer: self._emit_line(self._stream_name, self._buffer) self._buffer = "" @dataclass(frozen=True) class _WanGPRuntime: module: Any root: Path config_path: Path cli_args: tuple[str, ...] 
class SessionJob:
    """Handle for one asynchronous generation submitted to a :class:`WanGPSession`.

    Events produced while the job runs are published on :attr:`events`; the final
    :class:`GenerationResult` is available through :meth:`result` / :meth:`join`.
    """

    def __init__(self, session: "WanGPSession") -> None:
        self._session = session
        self.events = SessionStream()
        self._done = threading.Event()
        self._cancel_requested = threading.Event()
        self._thread: threading.Thread | None = None
        self._result: GenerationResult | None = None

    def _bind_thread(self, thread: threading.Thread) -> None:
        """Remember the thread that runs this job."""
        self._thread = thread

    def _set_result(self, result: GenerationResult) -> None:
        """Store the final result and wake every :meth:`result` waiter."""
        self._result = result
        self._done.set()

    def cancel(self) -> None:
        """Request cancellation; the session's job loop polls this flag."""
        self._cancel_requested.set()

    def result(self, timeout: float | None = None) -> GenerationResult:
        """Block until the job finishes and return its result.

        Raises:
            TimeoutError: if ``timeout`` seconds elapse before the job finishes.
        """
        if not self._done.wait(timeout=timeout):
            raise TimeoutError("WanGP session job timed out")
        return self._result or GenerationResult(
            success=False,
            generated_files=[],
            errors=[],
            total_tasks=0,
            successful_tasks=0,
            failed_tasks=0,
        )

    def join(self, timeout: float | None = None) -> GenerationResult:
        """Alias of :meth:`result`."""
        return self.result(timeout=timeout)

    @property
    def done(self) -> bool:
        return self._done.is_set()

    @property
    def cancel_requested(self) -> bool:
        return self._cancel_requested.is_set()


class WanGPSession:
    """In-process session that runs WanGP generation tasks on background threads.

    At most one job per session runs at a time, and generations from all sessions are
    serialized through a process-wide lock because they share one loaded ``wgp`` module.
    """

    def __init__(
        self,
        *,
        root: str | os.PathLike[str] | None = None,
        config_path: str | os.PathLike[str] | None = None,
        output_dir: str | os.PathLike[str] | None = None,
        callbacks: object | None = None,
        cli_args: Sequence[str] = (),
        console_output: bool = True,
        console_isatty: bool = True,
    ) -> None:
        self._root = Path(root or Path(__file__).resolve().parents[1]).resolve()
        self._config_path = (
            Path(config_path).resolve() if config_path is not None else (self._root / "wgp_config.json").resolve()
        )
        self._output_dir = Path(output_dir).resolve() if output_dir is not None else None
        self._callbacks = callbacks
        self._cli_args = tuple(str(arg) for arg in cli_args)
        self._console_output = bool(console_output)
        self._console_isatty = bool(console_isatty)
        self._state = self._create_headless_state()
        self._active_job: SessionJob | None = None
        self._job_lock = threading.Lock()
        self._attachment_keys: tuple[str, ...] | None = None

    def ensure_ready(self) -> "WanGPSession":
        """Load the WanGP runtime now instead of on first use; returns ``self``."""
        self._ensure_runtime()
        return self

    def submit(self, source: str | os.PathLike[str] | dict[str, Any] | list[dict[str, Any]]) -> SessionJob:
        """Start a job from a settings path, a settings dict, or a manifest list."""
        tasks = self._normalize_source(source, caller_base_path=self._get_caller_base_path())
        return self._submit_tasks(tasks)

    def submit_task(self, settings: dict[str, Any]) -> SessionJob:
        """Start a job running a single task."""
        caller_base_path = self._get_caller_base_path()
        task = self._normalize_task(settings, task_index=1)
        return self._submit_tasks([self._absolutize_task_paths(task, caller_base_path)])

    def submit_manifest(self, settings_list: list[dict[str, Any]]) -> SessionJob:
        """Start a job running every task in ``settings_list`` in order."""
        caller_base_path = self._get_caller_base_path()
        tasks = [
            self._absolutize_task_paths(self._normalize_task(settings, task_index=index + 1), caller_base_path)
            for index, settings in enumerate(settings_list)
        ]
        return self._submit_tasks(tasks)

    def run(self, source: str | os.PathLike[str] | dict[str, Any] | list[dict[str, Any]]) -> GenerationResult:
        """Blocking variant of :meth:`submit`."""
        return self.submit(source).result()

    def run_task(self, settings: dict[str, Any]) -> GenerationResult:
        """Blocking variant of :meth:`submit_task`."""
        return self.submit_task(settings).result()

    def run_manifest(self, settings_list: list[dict[str, Any]]) -> GenerationResult:
        """Blocking variant of :meth:`submit_manifest`."""
        return self.submit_manifest(settings_list).result()

    def close(self) -> None:
        """Release the loaded WanGP model, if the runtime has been loaded.

        Does nothing when no runtime was ever loaded (previously this imported WanGP just
        to release a model that did not exist). Because it takes the generation lock, it
        waits for an in-flight generation to finish first.
        """
        with _RUNTIME_LOCK:
            runtime_loaded = _RUNTIME is not None
        if not runtime_loaded:
            return
        runtime = self._ensure_runtime()
        with _GENERATION_LOCK, _pushd(runtime.root):
            runtime.module.release_model()

    def cancel(self) -> None:
        """Request cancellation of the active job, if any."""
        with self._job_lock:
            job = self._active_job
        if job is not None:
            job.cancel()

    @staticmethod
    def _create_headless_state() -> dict[str, Any]:
        """Build the minimal UI-less state dict that the ``wgp`` module reads and writes."""
        return {
            "gen": {
                "queue": [],
                "in_progress": False,
                "file_list": [],
                "file_settings_list": [],
                "audio_file_list": [],
                "audio_file_settings_list": [],
                "selected": 0,
                "audio_selected": 0,
                "prompt_no": 0,
                "prompts_max": 0,
                "repeat_no": 0,
                "total_generation": 1,
                "window_no": 0,
                "total_windows": 0,
                "progress_status": "",
                "process_status": "process:main",
            },
            "loras": [],
        }

    def _submit_tasks(self, tasks: list[dict[str, Any]]) -> SessionJob:
        """Start a job thread for ``tasks`` (deep-copied).

        Raises:
            RuntimeError: if this session already has an unfinished job.
        """
        with self._job_lock:
            if self._active_job is not None and not self._active_job.done:
                raise RuntimeError("WanGP session already has a generation in progress")
            job = SessionJob(self)
            thread = threading.Thread(
                target=self._run_job,
                args=(job, copy.deepcopy(tasks)),
                daemon=True,
                name="wangp-session-job",
            )
            job._bind_thread(thread)
            self._active_job = job
            thread.start()
            return job

    def _run_job(self, job: SessionJob, tasks: list[dict[str, Any]]) -> None:
        """Run ``tasks`` for ``job`` and publish progress, events and the final result.

        Executed on the ``wangp-session-job`` thread. It always marks ``job`` as done, even
        if a user callback raises, and then clears the session's active job.
        """
        stream = AsyncStream()
        gen = self._state["gen"]
        # Output lists accumulate across runs; remember where this run's outputs start.
        base_file_count = len(gen["file_list"])
        base_audio_count = len(gen["audio_file_list"])
        total_tasks = len(tasks)
        task_summary: dict[str, Any] = {
            "errors": [],
            "successful_tasks": 0,
            "failed_tasks": 0,
            "total_tasks": total_tasks,
        }
        error_result: GenerationResult | None = None
        try:
            runtime = self._ensure_runtime()
            with _GENERATION_LOCK, _pushd(runtime.root):
                try:
                    self._configure_runtime(runtime)
                    self._prepare_state_for_run(tasks)
                    job.events.put("started", {"tasks": total_tasks})
                    self._run_worker_and_pump(job, runtime, tasks, stream, task_summary)
                    outputs = self._collect_outputs(base_file_count, base_audio_count)
                finally:
                    # Reset while still holding _GENERATION_LOCK. gen_in_progress lives on the
                    # shared wgp module, and the next queued job (of this or another session)
                    # may prepare its state as soon as the lock is released. Resetting later
                    # would clobber that job's state.
                    self._reset_state_after_run()
            if job.cancel_requested and not task_summary["errors"]:
                task_summary["errors"].append(
                    GenerationError(message="Generation was cancelled", stage="cancelled")
                )
                task_summary["failed_tasks"] = max(task_summary["failed_tasks"], 1)
            result = GenerationResult(
                success=not task_summary["errors"],
                generated_files=outputs,
                errors=list(task_summary["errors"]),
                total_tasks=task_summary["total_tasks"],
                successful_tasks=task_summary["successful_tasks"],
                failed_tasks=task_summary["failed_tasks"],
            )
            job.events.put("completed", result)
            self._emit_callback("on_complete", result)
            job._set_result(result)
        except BaseException as exc:
            failure = self._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
            error_result = GenerationResult(
                success=False,
                generated_files=[],
                errors=[failure],
                total_tasks=total_tasks,
                successful_tasks=task_summary["successful_tasks"],
                failed_tasks=max(task_summary["failed_tasks"], 1 if total_tasks > 0 else 0),
            )
            job.events.put("error", failure)
            self._emit_callback("on_error", failure)
            job.events.put("completed", error_result)
            self._emit_callback("on_complete", error_result)
            job._set_result(error_result)
        finally:
            if not job.done:
                # A callback raised while reporting the outcome; still unblock result() waiters
                # instead of leaving them hanging forever.
                job._set_result(
                    error_result
                    or GenerationResult(
                        success=False,
                        generated_files=[],
                        errors=list(task_summary["errors"]),
                        total_tasks=total_tasks,
                        successful_tasks=task_summary["successful_tasks"],
                        failed_tasks=max(task_summary["failed_tasks"], 1 if total_tasks > 0 else 0),
                    )
                )
            job.events.close()
            with self._job_lock:
                if self._active_job is job:
                    self._active_job = None

    def _run_worker_and_pump(
        self,
        job: SessionJob,
        runtime: _WanGPRuntime,
        tasks: list[dict[str, Any]],
        stream: AsyncStream,
        task_summary: dict[str, Any],
    ) -> None:
        """Run the tasks on a worker thread and relay its commands to ``job`` until it exits.

        Must be called while holding ``_GENERATION_LOCK``. If relaying a command raises
        (for example a user callback), generation is aborted and the worker is joined before
        the exception propagates. This keeps generation from running in the background after
        the lock is released.
        """

        def emit_line(stream_name: str, line: str) -> None:
            self._emit_stream(job, stream_name, line)

        def worker() -> None:
            stdout_capture = _OutputCapture(
                "stdout",
                emit_line,
                console=sys.__stdout__ if self._console_output else None,
                console_isatty=self._console_isatty,
            )
            stderr_capture = _OutputCapture(
                "stderr",
                emit_line,
                console=sys.__stderr__ if self._console_output else None,
                console_isatty=self._console_isatty,
            )
            try:
                with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
                    self._run_tasks_worker(runtime.module, tasks, stream, job, task_summary)
            except BaseException as exc:
                failure = self._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
                task_summary["errors"].append(failure)
                stream.output_queue.push("error", failure)
            finally:
                try:
                    stdout_capture.flush()
                    stderr_capture.flush()
                finally:
                    stream.output_queue.push("worker_exit", None)

        worker_thread = threading.Thread(target=worker, daemon=True, name="wangp-session-worker")
        worker_thread.start()
        try:
            while True:
                if job.cancel_requested:
                    self._request_cancel_unlocked(runtime.module)
                # Check liveness BEFORE popping. If the worker was already dead and the queue
                # is empty, nothing more can arrive. This neither drops trailing events nor
                # spins forever if the worker died without pushing "worker_exit".
                worker_finished = not worker_thread.is_alive()
                item = stream.output_queue.pop()
                if item is None:
                    if worker_finished:
                        break
                    time.sleep(0.01)
                    continue
                command, data = item
                if command == "worker_exit":
                    break
                self._handle_command(job, runtime.module, tasks, command, data)
        except BaseException:
            self._request_cancel_unlocked(runtime.module)
            worker_thread.join()
            raise
        worker_thread.join(timeout=0.1)

    def _run_tasks_worker(
        self,
        wgp,
        tasks: list[dict[str, Any]],
        stream: AsyncStream,
        job: SessionJob,
        task_summary: dict[str, Any],
    ) -> None:
        """Validate and generate each task in order, recording outcomes in ``task_summary``.

        Runs on the worker thread; all events are pushed to ``stream.output_queue`` for the
        job thread to relay. Stops early when the job is cancelled or WanGP aborts.
        """
        expected_args = set(inspect.signature(wgp.generate_video).parameters.keys())
        total_tasks = len(tasks)
        for task_index, task in enumerate(tasks, start=1):
            if job.cancel_requested:
                break
            self._state["gen"]["prompt_no"] = task_index
            self._state["gen"]["prompts_max"] = total_tasks
            self._state["gen"]["queue"] = tasks
            task_id = task.get("id")
            task_errors: list[GenerationError] = []

            def send_cmd(command: str, data: Any = None) -> None:
                # Command sink handed to wgp.generate_video. Errors are recorded per task
                # before being forwarded.
                if command == "error":
                    failure = self._make_generation_error(
                        data,
                        task_index=task_index,
                        task_id=task_id,
                        stage="generation",
                    )
                    task_errors.append(failure)
                    stream.output_queue.push("error", failure)
                    return
                stream.output_queue.push(command, data)

            validated_settings, validation_error = wgp.validate_task(task, self._state)
            if validated_settings is None:
                failure = GenerationError(
                    message=validation_error or f"Task {task_index} failed validation",
                    task_index=task_index,
                    task_id=task_id,
                    stage="validation",
                )
                task_summary["errors"].append(failure)
                task_summary["failed_tasks"] += 1
                stream.output_queue.push("error", failure)
                continue

            task_settings = validated_settings.copy()
            task_settings["state"] = self._state
            # Only pass the keyword arguments that generate_video actually accepts.
            filtered_params = {key: value for key, value in task_settings.items() if key in expected_args}
            plugin_data = task.get("plugin_data", {})
            try:
                success = wgp.generate_video(task, send_cmd, plugin_data=plugin_data, **filtered_params)
            except BaseException as exc:
                if not task_errors:
                    # Only report the exception when send_cmd has not already reported an
                    # error for this task; otherwise the same error would be emitted twice.
                    failure = self._make_generation_error(
                        exc,
                        task_index=task_index,
                        task_id=task_id,
                        stage="generation",
                    )
                    task_errors.append(failure)
                    stream.output_queue.push("error", failure)
                success = False

            if self._state["gen"].get("abort", False) or job.cancel_requested:
                cancelled = GenerationError(
                    message="Generation was cancelled",
                    task_index=task_index,
                    task_id=task_id,
                    stage="cancelled",
                )
                task_errors.append(cancelled)
                stream.output_queue.push("error", cancelled)
                task_summary["errors"].extend(task_errors)
                task_summary["failed_tasks"] += 1
                break

            if task_errors:
                task_summary["errors"].extend(task_errors)
                task_summary["failed_tasks"] += 1
                continue

            if not success:
                failure = GenerationError(
                    message=f"Task {task_index} did not complete successfully",
                    task_index=task_index,
                    task_id=task_id,
                    stage="generation",
                )
                task_summary["errors"].append(failure)
                task_summary["failed_tasks"] += 1
                stream.output_queue.push("error", failure)
                continue

            task_summary["successful_tasks"] += 1

    def _handle_command(self, job: SessionJob, wgp, tasks: list[dict[str, Any]], command: str, data: Any) -> None:
        """Translate one worker command into a job event and the matching callback."""
        if command == "progress":
            progress = self._build_progress_update(data)
            job.events.put("progress", progress)
            self._emit_callback("on_progress", progress)
            return
        if command == "preview":
            preview = self._build_preview_update(wgp, tasks, data)
            if preview is not None:
                job.events.put("preview", preview)
                self._emit_callback("on_preview", preview)
            return
        if command == "status":
            text = str(data or "")
            job.events.put("status", text)
            self._emit_callback("on_status", text)
            return
        if command == "info":
            text = str(data or "")
            job.events.put("info", text)
            self._emit_callback("on_info", text)
            return
        if command == "output":
            job.events.put("output", data)
            self._emit_callback("on_output", data)
            return
        if command == "refresh_models":
            job.events.put("refresh_models", data)
            return
        if command == "error":
            error = data if isinstance(data, GenerationError) else self._make_generation_error(data)
            job.events.put("error", error)
            self._emit_callback("on_error", error)
            return
        job.events.put(command, data)

    def _build_progress_update(self, data: Any) -> ProgressUpdate:
        """Parse a "progress" payload into a :class:`ProgressUpdate`.

        A list whose head is a ``(current, total)`` tuple carries step counts, the status
        text at index 1 and, optionally, a unit at index 3. Any other payload is treated as
        status text.
        """
        current_step: int | None = None
        total_steps: int | None = None
        status = ""
        unit: str | None = None
        if isinstance(data, list) and data:
            head = data[0]
            if isinstance(head, tuple) and len(head) == 2:
                current_step = int(head[0])
                total_steps = int(head[1])
                status = str(data[1] if len(data) > 1 else "")
                if len(data) > 3:
                    unit = str(data[3])
            else:
                status = str(data[1] if len(data) > 1 else head)
        else:
            status = str(data or "")

        raw_phase = None
        progress_phase = self._state["gen"].get("progress_phase")
        if isinstance(progress_phase, tuple) and progress_phase:
            raw_phase = str(progress_phase[0] or "")
        phase = self._normalize_phase(raw_phase or status)
        progress = self._estimate_progress(phase, current_step, total_steps)
        return ProgressUpdate(
            phase=phase,
            status=status,
            progress=progress,
            current_step=current_step,
            total_steps=total_steps,
            raw_phase=raw_phase,
            unit=unit,
        )

    def _build_preview_update(self, wgp, tasks: list[dict[str, Any]], payload: Any) -> PreviewUpdate | None:
        """Render a preview image for the first queued task's model type."""
        progress = self._build_progress_update([0, self._state["gen"].get("progress_status", "")])
        model_type = ""
        queue_tasks = self._state["gen"].get("queue") or tasks
        if queue_tasks:
            model_type = str(self._get_task_settings(queue_tasks[0]).get("model_type", ""))
        image = wgp.generate_preview(model_type, payload) if model_type else None
        return PreviewUpdate(
            image=image,
            phase=progress.phase,
            status=progress.status,
            progress=progress.progress,
            current_step=progress.current_step,
            total_steps=progress.total_steps,
        )

    def _emit_stream(self, job: SessionJob, stream_name: str, line: str) -> None:
        """Publish one captured stdout/stderr line as a "stream" event and callback."""
        message = StreamMessage(stream=stream_name, text=line)
        job.events.put("stream", message)
        self._emit_callback("on_stream", message)

    def _emit_callback(self, method_name: str, payload: Any) -> None:
        """Invoke ``callbacks.<method_name>(payload)`` and ``callbacks.on_event`` if defined."""
        callback = self._callbacks
        if callback is None:
            return
        method = getattr(callback, method_name, None)
        if callable(method):
            method(payload)
        on_event = getattr(callback, "on_event", None)
        if callable(on_event):
            on_event(SessionEvent(kind=method_name.removeprefix("on_"), data=payload))

    def _configure_runtime(self, runtime: _WanGPRuntime) -> None:
        """Apply session settings (no notification sound, output dirs) to the wgp module."""
        runtime.module.server_config["notification_sound_enabled"] = 0
        if self._output_dir is not None:
            self._output_dir.mkdir(parents=True, exist_ok=True)
            runtime.module.server_config["save_path"] = str(self._output_dir)
            runtime.module.server_config["image_save_path"] = str(self._output_dir)
            runtime.module.server_config["audio_save_path"] = str(self._output_dir)
            runtime.module.save_path = str(self._output_dir)
            runtime.module.image_save_path = str(self._output_dir)
            runtime.module.audio_save_path = str(self._output_dir)
        for output_path in (
            runtime.module.save_path,
            runtime.module.image_save_path,
            runtime.module.audio_save_path,
        ):
            Path(output_path).mkdir(parents=True, exist_ok=True)

    def _prepare_state_for_run(self, tasks: list[dict[str, Any]]) -> None:
        """Mark the session state (and the wgp module) as generating ``tasks``."""
        gen = self._state["gen"]
        gen["queue"] = tasks
        set_main_generation_running(self._state, True)
        gen["process_status"] = "process:main"
        gen["progress_status"] = ""
        gen["progress_phase"] = ("", -1)
        gen["abort"] = False
        gen["early_stop"] = False
        gen["early_stop_forwarded"] = False
        gen["preview"] = None
        gen["status"] = "Generating..."
        gen["in_progress"] = True
        self._ensure_runtime().module.gen_in_progress = True

    def _reset_state_after_run(self) -> None:
        """Clear the per-run state set by :meth:`_prepare_state_for_run`."""
        gen = self._state["gen"]
        gen["queue"] = []
        set_main_generation_running(self._state, False)
        gen["process_status"] = "process:main"
        gen["progress_status"] = ""
        gen["progress_phase"] = ("", -1)
        gen["abort"] = False
        gen["early_stop"] = False
        gen["early_stop_forwarded"] = False
        gen.pop("in_progress", None)
        self._ensure_runtime().module.gen_in_progress = False

    def _collect_outputs(self, base_file_count: int, base_audio_count: int) -> list[str]:
        """Return resolved paths of files and audio files added since the given counts."""
        gen = self._state["gen"]
        files = gen["file_list"][base_file_count:]
        audio_files = gen["audio_file_list"][base_audio_count:]
        return [str(Path(path).resolve()) for path in [*files, *audio_files]]

    def _request_cancel_unlocked(self, wgp) -> None:
        """Set the abort flags WanGP polls, and interrupt the loaded model if there is one."""
        gen = self._state["gen"]
        gen["resume"] = True
        gen["abort"] = True
        if wgp.wan_model is not None:
            wgp.wan_model._interrupt = True

    def _normalize_source(
        self,
        source: str | os.PathLike[str] | dict[str, Any] | list[dict[str, Any]],
        *,
        caller_base_path: Path,
    ) -> list[dict[str, Any]]:
        """Turn any accepted ``source`` form into a list of normalized tasks.

        Raises:
            TypeError: if ``source`` is not a path, dict or list.
        """
        if isinstance(source, (str, os.PathLike)):
            return self._load_tasks_from_path(self._resolve_source_path(Path(source), caller_base_path), caller_base_path)
        if isinstance(source, list):
            return [
                self._absolutize_task_paths(self._normalize_task(task, task_index=index + 1), caller_base_path)
                for index, task in enumerate(source)
            ]
        if isinstance(source, dict):
            if isinstance(source.get("tasks"), list):
                tasks = source["tasks"]
                return [
                    self._absolutize_task_paths(self._normalize_task(task, task_index=index + 1), caller_base_path)
                    for index, task in enumerate(tasks)
                ]
            return [self._absolutize_task_paths(self._normalize_task(source, task_index=1), caller_base_path)]
        raise TypeError("WanGP session source must be a path, a settings dict, or a manifest list")

    def _normalize_task(self, task: dict[str, Any], *, task_index: int) -> dict[str, Any]:
        """Return a deep-copied task with ``id``, ``params`` and ``plugin_data`` filled in.

        A bare settings dict (no ``params``/``settings`` key) is wrapped as the task's
        ``params``; a ``settings`` key is renamed to ``params``.

        Raises:
            TypeError: if ``task`` is not a dict.
        """
        if not isinstance(task, dict):
            raise TypeError(f"Task {task_index} must be a dictionary")
        normalized = copy.deepcopy(task)
        if "settings" in normalized and "params" not in normalized:
            normalized["params"] = normalized.pop("settings")
        if "params" not in normalized:
            normalized = {"id": task_index, "params": normalized, "plugin_data": {}}
        normalized.setdefault("id", task_index)
        normalized.setdefault("plugin_data", {})
        normalized.setdefault("params", {})
        settings = normalized["params"]
        if isinstance(settings, dict):
            self._normalize_settings_values(settings)
            normalized.setdefault("prompt", settings.get("prompt", ""))
            normalized.setdefault("length", settings.get("video_length"))
            normalized.setdefault("steps", settings.get("num_inference_steps"))
            normalized.setdefault("repeats", settings.get("repeat_generation", 1))
        return normalized

    @staticmethod
    def _normalize_settings_values(settings: dict[str, Any]) -> None:
        """Convert a numeric ``force_fps`` to a string in place ("24" for 24.0, "23.976" as-is)."""
        force_fps = settings.get("force_fps")
        if isinstance(force_fps, (int, float)) and not isinstance(force_fps, bool):
            if isinstance(force_fps, float) and not force_fps.is_integer():
                settings["force_fps"] = str(force_fps)
            else:
                settings["force_fps"] = str(int(force_fps))

    @staticmethod
    def _get_task_settings(task: dict[str, Any]) -> dict[str, Any]:
        """Return the task's ``params`` (or legacy ``settings``) dict, or ``{}``."""
        settings = task.get("params")
        if isinstance(settings, dict):
            return settings
        settings = task.get("settings")
        if isinstance(settings, dict):
            return settings
        return {}

    def _load_tasks_from_path(self, path: Path, caller_base_path: Path) -> list[dict[str, Any]]:
        """Load tasks from a ``.json`` settings file or a WanGP queue archive.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            RuntimeError: if WanGP reports an error parsing the queue archive.
        """
        runtime = self._ensure_runtime()
        if not path.exists():
            raise FileNotFoundError(path)
        if path.suffix.lower() == ".json":
            return self._load_settings_json(path, caller_base_path)
        with _pushd(runtime.root):
            tasks, error = runtime.module._parse_queue_zip(str(path), self._state)
        if error:
            raise RuntimeError(error)
        return [self._normalize_task(task, task_index=index + 1) for index, task in enumerate(tasks)]

    def _load_settings_json(self, path: Path, caller_base_path: Path) -> list[dict[str, Any]]:
        """Load tasks from a JSON object, a list of tasks, or a ``{"tasks": [...]}`` object."""
        with path.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if isinstance(payload, list):
            raw_tasks = payload
        elif isinstance(payload, dict) and isinstance(payload.get("tasks"), list):
            raw_tasks = payload["tasks"]
        elif isinstance(payload, dict):
            raw_tasks = [payload]
        else:
            raise RuntimeError("Settings file must contain a JSON object or a list of tasks")
        tasks = [self._normalize_task(task, task_index=index + 1) for index, task in enumerate(raw_tasks)]
        return [self._absolutize_task_paths(task, caller_base_path) for task in tasks]

    @staticmethod
    def _get_caller_base_path() -> Path:
        """Directory that relative paths supplied by the caller are resolved against."""
        return Path.cwd().resolve()

    @staticmethod
    def _resolve_source_path(path: Path, caller_base_path: Path) -> Path:
        """Resolve ``path`` against ``caller_base_path`` unless it is already absolute."""
        if path.is_absolute():
            return path.resolve()
        return (caller_base_path / path).resolve()

    def _absolutize_task_paths(self, task: dict[str, Any], caller_base_path: Path) -> dict[str, Any]:
        """Return a copy of ``task`` whose attachment settings hold absolute paths.

        The job later runs with the WanGP root as cwd, so caller-relative paths must be
        fixed up front.
        """
        normalized = copy.deepcopy(task)
        settings = normalized.get("params")
        if not isinstance(settings, dict):
            return normalized
        for key in self._get_attachment_keys():
            if key not in settings:
                continue
            settings[key] = self._absolutize_setting_path(settings[key], caller_base_path)
        return normalized

    def _get_attachment_keys(self) -> tuple[str, ...]:
        """Settings keys the wgp module declares as file attachments (cached)."""
        if self._attachment_keys is None:
            runtime = self._ensure_runtime()
            keys = getattr(runtime.module, "ATTACHMENT_KEYS", ())
            self._attachment_keys = tuple(str(key) for key in keys)
        return self._attachment_keys

    def _absolutize_setting_path(self, value: Any, caller_base_path: Path) -> Any:
        """Resolve a path string (or a list of them) against ``caller_base_path``.

        Non-string values and blank strings are returned unchanged.
        """
        if isinstance(value, list):
            return [self._absolutize_setting_path(item, caller_base_path) for item in value]
        if isinstance(value, os.PathLike):
            value = os.fspath(value)
        if not isinstance(value, str) or not value.strip():
            return value
        path = Path(value)
        if path.is_absolute():
            return str(path.resolve())
        return str((caller_base_path / path).resolve())

    @staticmethod
    def _make_generation_error(
        error: Any,
        *,
        task_index: int | None = None,
        task_id: Any = None,
        stage: str | None = None,
    ) -> GenerationError:
        """Wrap an exception or message as a :class:`GenerationError` (passed through if already one)."""
        if isinstance(error, GenerationError):
            return error
        if isinstance(error, BaseException):
            message = str(error) or error.__class__.__name__
        else:
            message = str(error)
        return GenerationError(message=message, task_index=task_index, task_id=task_id, stage=stage)

    def _ensure_runtime(self) -> _WanGPRuntime:
        """Import and initialize the process-wide ``wgp`` module on first use.

        Raises:
            RuntimeError: if the runtime was already loaded with a different root, config or
                CLI args, or if ``wgp`` was imported from another directory.
            ValueError: if ``config_path`` is not named ``wgp_config.json``.
        """
        global _RUNTIME
        with _RUNTIME_LOCK:
            if _RUNTIME is not None:
                if (
                    _RUNTIME.root != self._root
                    or _RUNTIME.config_path != self._config_path
                    or _RUNTIME.cli_args != self._cli_args
                ):
                    raise RuntimeError("WanGP runtime already loaded with different root/config/cli args")
                return _RUNTIME

            argv = ["wgp.py", *self._cli_args]
            default_config_path = (self._root / "wgp_config.json").resolve()
            if self._config_path.name != "wgp_config.json":
                raise ValueError("config_path must point to a file named 'wgp_config.json'")
            if self._config_path != default_config_path:
                self._config_path.parent.mkdir(parents=True, exist_ok=True)
                if "--config" not in argv:
                    argv.extend(["--config", str(self._config_path.parent)])
            if str(self._root) not in sys.path:
                sys.path.insert(0, str(self._root))
            # wgp reads sys.argv and relative paths at import time.
            with _pushd(self._root), _temporary_argv(argv):
                module = importlib.import_module("wgp")
                module_root = Path(module.__file__).resolve().parent
                if module_root != self._root:
                    raise RuntimeError(f"WanGP module already loaded from {module_root}, expected {self._root}")
                if not hasattr(module, "app"):
                    module.app = module.WAN2GPApplication()
                module.download_ffmpeg()
            _RUNTIME = _WanGPRuntime(
                module=module,
                root=self._root,
                config_path=self._config_path,
                cli_args=self._cli_args,
            )
            _print_banner_once(module)
            return _RUNTIME

    @staticmethod
    def _normalize_phase(text: str | None) -> str:
        """Map free-form WanGP status/phase text to a stable phase identifier."""
        lowered = str(text or "").lower()
        if "denoising first pass" in lowered or "denoising 1st pass" in lowered:
            return "inference_stage_1"
        if "denoising second pass" in lowered or "denoising 2nd pass" in lowered:
            return "inference_stage_2"
        if "denoising third pass" in lowered or "denoising 3rd pass" in lowered:
            return "inference_stage_3"
        if "loading model" in lowered or lowered.startswith("loading"):
            return "loading_model"
        if "enhancing prompt" in lowered or "encoding prompt" in lowered or "encoding" in lowered:
            return "encoding_text"
        if "vae decoding" in lowered or "decoding" in lowered:
            return "decoding"
        if "saved" in lowered or "completed" in lowered or "output" in lowered:
            return "downloading_output"
        if "cancel" in lowered or "abort" in lowered:
            return "cancelled"
        return "inference"

    @staticmethod
    def _estimate_progress(phase: str, current_step: int | None, total_steps: int | None) -> int:
        """Estimate overall progress (percent) from the phase and, if known, the step ratio."""
        if total_steps is None or total_steps <= 0 or current_step is None:
            if phase == "loading_model":
                return 10
            if phase == "encoding_text":
                return 18
            if phase == "inference_stage_1":
                return 25
            if phase == "inference_stage_2":
                return 70
            if phase == "inference_stage_3":
                return 80
            if phase == "decoding":
                return 90
            if phase == "downloading_output":
                return 95
            if phase == "cancelled":
                return 0
            return 15
        ratio = max(0.0, min(1.0, current_step / total_steps))
        if phase == "loading_model":
            return min(15, 5 + int(ratio * 10))
        if phase == "encoding_text":
            return min(22, 12 + int(ratio * 10))
        if phase == "inference_stage_1":
            return min(68, 20 + int(ratio * 48))
        if phase == "inference_stage_2":
            return min(88, 68 + int(ratio * 20))
        if phase == "inference_stage_3":
            return min(89, 80 + int(ratio * 9))
        if phase == "decoding":
            return min(95, 85 + int(ratio * 10))
        if phase == "downloading_output":
            return min(98, 92 + int(ratio * 6))
        if phase == "cancelled":
            return 0
        return min(90, 20 + int(ratio * 65))


def init(
    *,
    root: str | os.PathLike[str] | None = None,
    config_path: str | os.PathLike[str] | None = None,
    output_dir: str | os.PathLike[str] | None = None,
    callbacks: object | None = None,
    cli_args: Sequence[str] = (),
    console_output: bool = True,
    console_isatty: bool = True,
) -> WanGPSession:
    """Create and eagerly initialize a reusable WanGP session.

    All arguments are forwarded to :class:`WanGPSession`. ``console_isatty`` is now
    exposed here as well; its default matches the session's.
    """
    return WanGPSession(
        root=root,
        config_path=config_path,
        output_dir=output_dir,
        callbacks=callbacks,
        cli_args=cli_args,
        console_output=console_output,
        console_isatty=console_isatty,
    ).ensure_ready()


@contextlib.contextmanager
def _pushd(path: Path) -> Iterator[None]:
    """Temporarily change the process working directory to ``path``."""
    previous = Path.cwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(previous)


@contextlib.contextmanager
def _temporary_argv(argv: Sequence[str]) -> Iterator[None]:
    """Temporarily replace ``sys.argv`` with ``argv``."""
    previous = list(sys.argv)
    sys.argv = list(argv)
    try:
        yield
    finally:
        sys.argv = previous


def _print_banner_once(module) -> None:
    """Write the WanGP version banner to the real stdout the first time it is called."""
    global _BANNER_PRINTED
    if _BANNER_PRINTED:
        return
    _BANNER_PRINTED = True
    banner = f"Powered by WanGP v{module.WanGP_version} - a DeepBeepMeep Production\n"
    console = sys.__stdout__ if sys.__stdout__ is not None else sys.stdout
    if console is not None:
        console.write(banner)
        console.flush()