# Wan2GP / shared/deepy/controller.py
# Uploaded by Egnalkram: "Upload folder using huggingface_hub" (commit 4689c2b, verified).
from __future__ import annotations
import json
import secrets
import time
import traceback
from dataclasses import dataclass
from typing import Any, Callable
import gradio as gr
from shared.deepy.config import (
DEEPY_ENABLED_KEY,
DEEPY_VRAM_MODE_KEY,
DEEPY_VRAM_MODE_UNLOAD,
deepy_available,
deepy_requirement_met,
normalize_deepy_enabled,
normalize_deepy_vram_mode,
set_deepy_runtime_config,
)
from shared.deepy import ui_settings as deepy_ui_settings
from shared.deepy.engine import (
AssistantEngine,
AssistantRuntimeHooks,
begin_assistant_turn,
clear_assistant_session,
get_or_create_assistant_session,
request_assistant_interrupt,
request_assistant_reset,
set_assistant_debug,
set_assistant_tool_ui_settings,
tools as AssistantTools,
)
from shared.gradio import assistant_chat
from shared.utils.thread_utils import AsyncStream, async_run_in
# Presumably the identifier Deepy uses when arbitrating the GPU; it is not
# referenced by the code shown in this module (TODO confirm usage elsewhere).
_DEEPY_GPU_PROCESS_ID: str = "deepy"
# User-facing explanations for why Deepy cannot run, surfaced in the chat
# panel and raised as gr.Error by the controller.
_DEEPY_REQUIREMENT_TEXT: str = "Deepy requires Prompt Enhancer to be set to Qwen3.5VL Abliterated 4B or 9B."
_DEEPY_DISABLED_TEXT: str = "Deepy is disabled in Configuration > Deepy."
@dataclass(slots=True)
class DeepyDeps:
    """Host-application callbacks that DeepyController reaches through ``self._deps``.

    Field order defines the positional order of the generated ``__init__``;
    ``create_controller`` builds this from keyword arguments.
    """

    # --- configuration access ---
    get_server_config: Callable[[], dict[str, Any]]
    get_server_config_filename: Callable[[], str]  # path the server config is saved to ("" = don't save)
    get_verbose_level: Callable[[], int]  # level >= 2 turns on assistant debug output
    # Returns (instructions, max_new_tokens) for the prompt enhancer.
    resolve_prompt_enhancer_settings: Callable[..., tuple[Any, int]]
    # NOTE(review): the two model lookups below are not used by the controller code in this module.
    get_state_model_type: Callable[[Any], str]
    get_model_def: Callable[[str], Any]
    # --- prompt-enhancer runtime ---
    ensure_prompt_enhancer_loaded: Callable[..., tuple[Any, Any]]  # returns (model, tokenizer); accepts override_profile=
    unload_prompt_enhancer_runtime: Callable[[], None]
    get_image_caption_model: Callable[[], Any]  # None when the vision runtime is unavailable
    get_image_caption_processor: Callable[[], Any]  # None when the vision runtime is unavailable
    get_enhancer_offloadobj: Callable[[], Any]  # object exposing unload_all(), or None
    # --- GPU arbitration (all keyed by the UI state object) ---
    acquire_gpu: Callable[[Any], None]
    release_gpu: Callable[..., None]  # accepts keep_resident / release_vram_callback / force_release_on_acquire
    register_gpu_resident: Callable[..., None]
    clear_gpu_resident: Callable[[Any], None]
    # --- UI, queue and output-file helpers (several are forwarded to the assistant tools) ---
    get_new_refresh_id: Callable[[], Any]
    get_gen_info: Callable[[Any], dict[str, Any]]
    get_processed_queue: Callable[[dict[str, Any]], tuple[list[Any], list[Any], list[Any], list[Any]]]
    get_output_filepath: Callable[[str, bool, bool], str]
    record_file_metadata: Callable[..., Any]
    # Runs a prompt-enhancer/assistant turn; ask_ai calls it with mode "AK".
    exec_prompt_enhancer_engine: Callable[..., Any]
    # NOTE(review): not used by the controller code in this module.
    clear_queue_action: Callable[[Any], Any]
def _unload_prompt_enhancer_runtime(prompt_enhancer_image_caption_model, prompt_enhancer_llm_model) -> None:
    """Unload both prompt-enhancer models via ``shared.prompt_enhancer``.

    The import is deferred to call time so this module can be imported without
    pulling in the prompt-enhancer stack.
    """
    from shared.prompt_enhancer import unload_prompt_enhancer_models as _unload_models

    models_to_unload = (prompt_enhancer_image_caption_model, prompt_enhancer_llm_model)
    _unload_models(*models_to_unload)
class DeepyController:
    """Glue between the Gradio UI and the Deepy assistant engine.

    Every host-application service (config access, prompt-enhancer loading,
    GPU arbitration, queue/gallery helpers) is reached through the injected
    ``DeepyDeps`` instance stored in ``self._deps``.
    """

    def __init__(self, deps: "DeepyDeps"):
        self._deps = deps

    # ------------------------------------------------------------------
    # Configuration helpers
    # ------------------------------------------------------------------
    def get_verbose_level(self) -> int:
        """Return the host verbosity level as an int, or 0 if it cannot be read or parsed."""
        try:
            return int(self._deps.get_verbose_level() or 0)
        except Exception:
            # Best effort: a bad verbosity value must never break the assistant.
            return 0

    def _sync_debug_enabled(self) -> bool:
        """Push the debug flag (verbosity >= 2) into the engine and return it."""
        debug_enabled = self.get_verbose_level() >= 2
        set_assistant_debug(debug_enabled)
        return debug_enabled

    def _server_config(self) -> dict[str, Any]:
        """Return the live server config dict, or a fresh empty dict when none is set.

        Only ``None`` is replaced. An existing-but-empty config dict is returned
        as-is, so that settings stored into it reach the shared config object
        instead of a throwaway copy.
        """
        server_config = self._deps.get_server_config()
        return {} if server_config is None else server_config

    @staticmethod
    def _write_server_config(server_config: dict[str, Any], server_config_filename: str) -> None:
        """Serialize ``server_config`` as 4-space-indented JSON to ``server_config_filename``."""
        with open(server_config_filename, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(server_config, indent=4))

    def _persist_tool_ui_settings(self, normalized) -> None:
        """Store normalized tool UI settings into the server config and save it.

        Nothing further happens when ``store_assistant_tool_ui_settings`` returns a
        falsy value (presumably "nothing changed"). The file write is skipped when
        no config filename is configured.
        """
        server_config = self._server_config()
        server_config_filename = str(self._deps.get_server_config_filename() or "").strip()
        if not deepy_ui_settings.store_assistant_tool_ui_settings(server_config, normalized):
            return
        set_deepy_runtime_config(server_config, server_config_filename)
        if server_config_filename:
            self._write_server_config(server_config, server_config_filename)

    def is_available(self) -> bool:
        """Return whether Deepy can run under the current server config."""
        return deepy_available(self._server_config())

    def requirement_error_text(self) -> str:
        """Return the user-facing reason Deepy cannot run, or "" if nothing blocks it."""
        server_config = self._server_config()
        if not deepy_requirement_met(server_config):
            return _DEEPY_REQUIREMENT_TEXT
        if not normalize_deepy_enabled(server_config.get(DEEPY_ENABLED_KEY, 0)):
            return _DEEPY_DISABLED_TEXT
        return ""

    def get_vram_mode(self) -> str:
        """Return the normalized Deepy VRAM mode (defaults to the "unload" mode)."""
        server_config = self._server_config()
        return normalize_deepy_vram_mode(server_config.get(DEEPY_VRAM_MODE_KEY, DEEPY_VRAM_MODE_UNLOAD))

    # ------------------------------------------------------------------
    # Runtime / VRAM management
    # ------------------------------------------------------------------
    def _ensure_vision_loaded(self, override_profile=None):
        """Load the prompt enhancer and return (image_caption_model, image_caption_processor).

        Raises:
            gr.Error: if either vision component is unavailable after loading.
        """
        self._deps.ensure_prompt_enhancer_loaded(override_profile=override_profile)
        image_caption_model = self._deps.get_image_caption_model()
        image_caption_processor = self._deps.get_image_caption_processor()
        if image_caption_model is None or image_caption_processor is None:
            raise gr.Error("Prompt enhancer vision runtime is not available.")
        return image_caption_model, image_caption_processor

    def _unload_weights(self) -> None:
        """Ask the enhancer offload object (if any) to unload all weights."""
        enhancer_offloadobj = self._deps.get_enhancer_offloadobj()
        if enhancer_offloadobj is not None:
            enhancer_offloadobj.unload_all()

    def _build_preload_release_callback(self) -> Callable[[], None]:
        """Return a callback that unloads the enhancer runtime and then its weights."""
        def _release_preloaded_runtime() -> None:
            try:
                self._deps.unload_prompt_enhancer_runtime()
            finally:
                # Weights are unloaded even if the runtime unload fails.
                self._unload_weights()
        return _release_preloaded_runtime

    def release_vram(self, state, clear_session_state=False, discard_runtime_snapshot=False):
        """Release the assistant's GPU-resident runtime for ``state``'s session.

        Args:
            state: UI state object the assistant session is attached to.
            clear_session_state: also wipe the assistant session afterwards.
            discard_runtime_snapshot: drop the saved runtime snapshot as well.
        """
        session = get_or_create_assistant_session(state)
        # Detach the callback before running it so it can only fire once.
        release_callback = session.release_vram_callback
        session.release_vram_callback = None
        # Flag is visible to the callback while it runs (presumably read by the engine); reset below.
        session.discard_runtime_snapshot_on_release = bool(discard_runtime_snapshot)
        self._deps.clear_gpu_resident(state)
        try:
            if callable(release_callback):
                release_callback()
        finally:
            if discard_runtime_snapshot:
                session.runtime_snapshot = None
                if len(session.rendered_token_ids) == 0:
                    session.pending_replay_reason = ""
            session.discard_runtime_snapshot_on_release = False
        if clear_session_state:
            clear_assistant_session(session)

    def preload_cli_runtime(self, state, override_profile=None) -> dict[str, Any]:
        """Load the prompt enhancer up front and, when it uses NanoVLLM, warm that runtime.

        On success the GPU is released while keeping the model resident, with a
        callback that unloads it later. On failure nothing is kept resident.

        Returns:
            ``{"status": "ready", "warmed_vllm": bool}``.
        """
        self._sync_debug_enabled()
        self._deps.clear_gpu_resident(state)
        self._deps.acquire_gpu(state)
        keep_resident = False
        warmed_vllm = False
        try:
            model, _tokenizer = self._deps.ensure_prompt_enhancer_loaded(override_profile=override_profile)
            from shared.prompt_enhancer import qwen35_text
            if qwen35_text._use_vllm_prompt_enhancer(model):
                engine = qwen35_text._get_or_create_vllm_engine(model, usage_mode="assistant")
                # Tiny reservation: enough to force the runtime to initialise.
                engine.reserve_runtime(prompt_len=64, max_tokens=1, cfg_scale=1.0)
                engine._ensure_llm()
                llm = getattr(engine, "_llm", None)
                if llm is None:
                    raise RuntimeError("Assistant NanoVLLM runtime is not available.")
                llm.model_runner.ensure_runtime_ready()
                engine.release_runtime_allocations()
                warmed_vllm = True
            keep_resident = True
            return {"status": "ready", "warmed_vllm": warmed_vllm}
        finally:
            self._deps.release_gpu(
                state,
                keep_resident=keep_resident,
                release_vram_callback=self._build_preload_release_callback() if keep_resident else None,
                force_release_on_acquire=True,
            )

    # ------------------------------------------------------------------
    # Tool UI settings
    # ------------------------------------------------------------------
    def update_tool_ui_settings(
        self,
        state,
        *,
        auto_cancel_queue_tasks=None,
        use_template_properties=None,
        width=None,
        height=None,
        num_frames=None,
        seed=None,
        video_with_speech_variant=None,
        image_generator_variant=None,
        image_editor_variant=None,
        video_generator_variant=None,
        speech_from_description_variant=None,
        speech_from_sample_variant=None,
        persist=False,
    ):
        """Apply tool UI settings to the session and return the normalized result.

        ``None`` arguments are passed through to ``set_assistant_tool_ui_settings``
        unchanged. When ``persist`` is true the normalized settings are also
        stored in the server config and saved to disk.
        """
        session = get_or_create_assistant_session(state)
        normalized = set_assistant_tool_ui_settings(
            session,
            auto_cancel_queue_tasks=auto_cancel_queue_tasks,
            use_template_properties=use_template_properties,
            width=width,
            height=height,
            num_frames=num_frames,
            seed=seed,
            video_with_speech_variant=video_with_speech_variant,
            image_generator_variant=image_generator_variant,
            image_editor_variant=image_editor_variant,
            video_generator_variant=video_generator_variant,
            speech_from_description_variant=speech_from_description_variant,
            speech_from_sample_variant=speech_from_sample_variant,
        )
        if persist:
            self._persist_tool_ui_settings(normalized)
        return normalized

    def persist_auto_cancel_queue_tasks(self, state, auto_cancel_queue_tasks):
        """Update and persist only the ``auto_cancel_queue_tasks`` setting.

        Returns:
            The normalized ``auto_cancel_queue_tasks`` value.
        """
        session = get_or_create_assistant_session(state)
        current = dict(session.tool_ui_settings or deepy_ui_settings.normalize_assistant_tool_ui_settings())
        current["auto_cancel_queue_tasks"] = auto_cancel_queue_tasks
        normalized = deepy_ui_settings.normalize_assistant_tool_ui_settings(**current)
        session.tool_ui_settings = dict(normalized)
        self._persist_tool_ui_settings(normalized)
        return normalized["auto_cancel_queue_tasks"]

    def store_selected_video_time(self, state, current_time):
        """Record the player's current time in the state's generation info.

        Non-numeric, negative, NaN or infinite values are stored as ``None``.
        """
        import math

        gen = self._deps.get_gen_info(state)
        try:
            value = float(current_time)
        except (TypeError, ValueError, OverflowError):
            value = None
        if value is None or not math.isfinite(value) or value < 0:
            gen["selected_video_time"] = None
        else:
            gen["selected_video_time"] = value

    # ------------------------------------------------------------------
    # Assistant turns
    # ------------------------------------------------------------------
    def create_tools(self, state, send_cmd, session=None):
        """Build the assistant tool registry bound to ``state`` and ``send_cmd``."""
        active_session = get_or_create_assistant_session(state) if session is None else session
        gen = self._deps.get_gen_info(state)
        return AssistantTools(
            gen,
            self._deps.get_processed_queue,
            send_cmd,
            session=active_session,
            get_output_filepath=self._deps.get_output_filepath,
            record_file_metadata=self._deps.record_file_metadata,
            get_server_config=self._server_config,
        )

    def run_assistant_prompt_turn(self, state, model_def, prompt_enhancer_modes, original_prompts, seed, override_profile=None, send_cmd=None, tools=None) -> None:
        """Run one assistant turn on ``original_prompts[0]`` using the prompt-enhancer LLM.

        Raises:
            gr.Error: if Deepy is disabled, its requirement is not met, or
                ``send_cmd``/``tools`` are missing.
        """
        debug_enabled = self._sync_debug_enabled()
        server_config = self._server_config()
        if not normalize_deepy_enabled(server_config.get(DEEPY_ENABLED_KEY, 0)):
            raise gr.Error(_DEEPY_DISABLED_TEXT)
        if not deepy_requirement_met(server_config):
            raise gr.Error(_DEEPY_REQUIREMENT_TEXT)
        if send_cmd is None or tools is None:
            raise gr.Error("Assistant mode requires a command stream and a tool registry.")
        enhancer_temperature = server_config.get("prompt_enhancer_temperature", 0.6)
        enhancer_top_p = server_config.get("prompt_enhancer_top_p", 0.9)
        randomize_seed = server_config.get("prompt_enhancer_randomize_seed", True)
        assistant_seed = secrets.randbits(32) if randomize_seed else (seed if seed is not None and seed >= 0 else 0)
        session = get_or_create_assistant_session(state)
        assistant_model_def = model_def
        _assistant_instructions, assistant_max_new_tokens = self._deps.resolve_prompt_enhancer_settings(assistant_model_def, prompt_enhancer_modes, is_image=False, text_encoder_max_tokens=1024)
        assistant = AssistantEngine(
            session,
            AssistantRuntimeHooks(
                acquire_gpu=lambda: self._deps.acquire_gpu(state),
                release_gpu=lambda keep_resident=False, release_vram_callback=None, force_release_on_acquire=True: self._deps.release_gpu(
                    state,
                    keep_resident=keep_resident,
                    release_vram_callback=release_vram_callback,
                    force_release_on_acquire=force_release_on_acquire,
                ),
                register_gpu_resident=lambda release_vram_callback=None, force_release_on_acquire=True: self._deps.register_gpu_resident(
                    state,
                    release_vram_callback=release_vram_callback,
                    force_release_on_acquire=force_release_on_acquire,
                ),
                clear_gpu_resident=lambda: self._deps.clear_gpu_resident(state),
                ensure_loaded=lambda: self._deps.ensure_prompt_enhancer_loaded(override_profile=override_profile),
                unload_runtime=self._deps.unload_prompt_enhancer_runtime,
                unload_weights=self._unload_weights,
                ensure_vision_loaded=lambda: self._ensure_vision_loaded(override_profile=override_profile),
            ),
            tools,
            send_cmd,
            debug_enabled=debug_enabled,
            thinking_enabled="K" in prompt_enhancer_modes,
            vram_mode=self.get_vram_mode(),
        )
        assistant.run_turn(
            original_prompts[0] if len(original_prompts) > 0 else "",
            # The assistant always gets at least 1024 new tokens.
            max_new_tokens=max(1024, int(assistant_max_new_tokens)),
            seed=assistant_seed,
            do_sample=True,
            temperature=enhancer_temperature,
            top_p=enhancer_top_p,
        )

    def ask_ai(self, state, ask_request):
        """Submit a chat request to Deepy and stream UI updates while it runs.

        This is a generator for a Gradio event handler. Every yielded item is a
        5-tuple. Judging from the commands handled below, the slots are: chat
        event, queue-refresh id, request textbox, gallery-refresh id, client id
        to abort (TODO confirm against the UI wiring). The turn itself runs on the
        "assistant" worker via ``async_run_in``. This generator relays whatever
        the worker pushes onto the shared stream until it sends ``"exit"``.
        """
        self._sync_debug_enabled()

        def get_refresh_id():
            return str(time.time()) + "_" + str(self._deps.get_new_refresh_id())

        def drain_chat_output_batch(first_payload):
            # Coalesce consecutive queued "chat_output" commands into one UI event.
            payloads = [first_payload]
            while True:
                next_item = com_stream.output_queue.top()
                if not isinstance(next_item, tuple) or len(next_item) < 1 or next_item[0] != "chat_output":
                    break
                _cmd, next_payload = com_stream.output_queue.pop()
                payloads.append(next_payload)
            return assistant_chat.build_event_batch(payloads)

        session = get_or_create_assistant_session(state)
        ask_request = str(ask_request or "").strip()
        if len(ask_request) == 0:
            yield gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
            return
        if not self.is_available():
            error_turn_id = assistant_chat.create_assistant_turn(session)
            error_event = assistant_chat.set_assistant_content(session, error_turn_id, self.requirement_error_text())
            yield error_event if error_event is not None else gr.update(), gr.update(), gr.update(value=""), gr.update(), gr.update()
            return
        # NOTE(review): result unused here; the call is kept in case it initializes per-state gen info.
        self._deps.get_gen_info(state)
        com_stream = AsyncStream()
        send_cmd = com_stream.output_queue.push
        queued = session.worker_active or session.queued_job_count > 0
        # A chat reset changes the epoch; jobs from an older epoch are skipped.
        queued_epoch = session.chat_epoch
        session.queued_job_count += 1
        user_message_id, _user_event = assistant_chat.add_user_message(session, ask_request, queued=queued)
        yield assistant_chat.build_sync_event(session), gr.update(), gr.update(value=""), gr.update(), gr.update()
        if queued:
            yield assistant_chat.build_status_event("Queued behind the current assistant task.", kind="queued"), gr.update(), gr.update(), gr.update(), gr.update()

        def queue_worker_func():
            # This job leaves the queue whether it runs or is skipped.
            session.queued_job_count = max(0, session.queued_job_count - 1)
            if queued_epoch != session.chat_epoch:
                send_cmd("exit", None)
                return
            session.interrupt_requested = False
            session.control_queue = com_stream.output_queue
            session.worker_active = True
            begin_assistant_turn(session, user_message_id, ask_request)
            send_cmd("chat_output", assistant_chat.build_sync_event(session))
            queued_badge_event = assistant_chat.set_message_badge(session, user_message_id, None)
            if queued_badge_event is not None:
                send_cmd("chat_output", queued_badge_event)
            my_tools = self.create_tools(state, send_cmd, session=session)
            try:
                self._deps.exec_prompt_enhancer_engine(state, None, "AK", [ask_request], None, None, False, False, 0, None, 3.5, send_cmd, my_tools)
            except Exception as e:
                # Worker boundary: report the crash in the chat instead of killing the thread silently.
                traceback.print_exc()
                error_turn_id = assistant_chat.create_assistant_turn(session)
                error_event = assistant_chat.set_assistant_content(session, error_turn_id, f"Assistant crashed: {e}")
                if error_event is not None:
                    send_cmd("chat_output", error_event)
                send_cmd("chat_output", assistant_chat.build_status_event(None, visible=False))
            finally:
                session.worker_active = False
                # Only detach the control queue if it is still ours.
                if session.control_queue is com_stream.output_queue:
                    session.control_queue = None
                if queued_epoch == session.chat_epoch:
                    send_cmd("chat_output", assistant_chat.build_sync_event(session))
                session.interrupt_requested = False
                send_cmd("exit", None)

        try:
            async_run_in("assistant", queue_worker_func)
        except Exception:
            # The worker never started, so it will never decrement the counter;
            # undo the increment so later requests are not marked "queued" forever.
            session.queued_job_count = max(0, session.queued_job_count - 1)
            raise
        while True:
            cmd, data = com_stream.output_queue.next()
            if cmd == "console_output":
                print(data)
            elif cmd == "chat_output":
                yield drain_chat_output_batch(data), gr.update(), gr.update(), gr.update(), gr.update()
            elif cmd == "load_queue_trigger":
                yield gr.update(), str(get_refresh_id()), gr.update(), gr.update(), gr.update()
            elif cmd == "abort_client_id":
                yield gr.update(), gr.update(), gr.update(), gr.update(), str(data or "")
            elif cmd == "refresh_gallery":
                yield gr.update(), gr.update(), gr.update(), str(get_refresh_id()), gr.update()
            elif cmd == "error":
                error_turn_id = assistant_chat.create_assistant_turn(session)
                error_event = assistant_chat.set_assistant_content(session, error_turn_id, str(data or "Assistant error."))
                yield error_event if error_event is not None else gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
            elif cmd == "exit":
                break

    def stop_ai(self, state):
        """Request an interrupt of the running assistant turn, if any (4-output UI update)."""
        session = get_or_create_assistant_session(state)
        if not session.worker_active:
            return gr.update(), gr.update(), gr.update(), gr.update()
        request_assistant_interrupt(session)
        return assistant_chat.build_status_event(None, visible=False), gr.update(), gr.update(), gr.update()

    def reset_ai(self, state):
        """Reset the assistant chat.

        If a turn is running, a reset is requested from the engine. Otherwise
        VRAM is released and the session state is cleared.
        """
        session = get_or_create_assistant_session(state)
        if session.worker_active:
            request_assistant_reset(session)
            assistant_chat.reset_session_chat(session)
        else:
            self.release_vram(state, True)
        session.chat_html = ""
        return assistant_chat.build_reset_event(), gr.update(), gr.update(value=""), gr.update()
def create_controller(**deps_kwargs) -> DeepyController:
    """Build a DeepyController whose dependencies are given as keyword arguments.

    Every keyword is forwarded to ``DeepyDeps``, so the accepted names are
    exactly that dataclass's fields.
    """
    dependencies = DeepyDeps(**deps_kwargs)
    controller = DeepyController(dependencies)
    return controller