from __future__ import annotations

import json
import secrets
import time
import traceback
from dataclasses import dataclass
from typing import Any, Callable

import gradio as gr

from shared.deepy.config import (
    DEEPY_ENABLED_KEY,
    DEEPY_VRAM_MODE_KEY,
    DEEPY_VRAM_MODE_UNLOAD,
    deepy_available,
    deepy_requirement_met,
    normalize_deepy_enabled,
    normalize_deepy_vram_mode,
    set_deepy_runtime_config,
)
from shared.deepy import ui_settings as deepy_ui_settings
from shared.deepy.engine import (
    AssistantEngine,
    AssistantRuntimeHooks,
    begin_assistant_turn,
    clear_assistant_session,
    get_or_create_assistant_session,
    request_assistant_interrupt,
    request_assistant_reset,
    set_assistant_debug,
    set_assistant_tool_ui_settings,
    tools as AssistantTools,
)
from shared.gradio import assistant_chat
from shared.utils.thread_utils import AsyncStream, async_run_in


_DEEPY_GPU_PROCESS_ID = "deepy"
_DEEPY_REQUIREMENT_TEXT = "Deepy requires Prompt Enhancer to be set to Qwen3.5VL Abliterated 4B or 9B."
_DEEPY_DISABLED_TEXT = "Deepy is disabled in Configuration > Deepy."


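# Bundle of host-application callables injected into DeepyController so this module
# stays decoupled from the main app: server config access, prompt-enhancer loading
# and unloading, GPU scheduling, queue/gallery access, and file metadata helpers.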
@dataclass(slots=True)
class DeepyDeps:
    get_server_config: Callable[[], dict[str, Any]]
    get_server_config_filename: Callable[[], str]
    get_verbose_level: Callable[[], int]
    resolve_prompt_enhancer_settings: Callable[..., tuple[Any, int]]
    get_state_model_type: Callable[[Any], str]
    get_model_def: Callable[[str], Any]
    ensure_prompt_enhancer_loaded: Callable[..., tuple[Any, Any]]
    unload_prompt_enhancer_runtime: Callable[[], None]
    get_image_caption_model: Callable[[], Any]
    get_image_caption_processor: Callable[[], Any]
    get_enhancer_offloadobj: Callable[[], Any]
    acquire_gpu: Callable[[Any], None]
    release_gpu: Callable[..., None]
    register_gpu_resident: Callable[..., None]
    clear_gpu_resident: Callable[[Any], None]
    get_new_refresh_id: Callable[[], Any]
    get_gen_info: Callable[[Any], dict[str, Any]]
    get_processed_queue: Callable[[dict[str, Any]], tuple[list[Any], list[Any], list[Any], list[Any]]]
    get_output_filepath: Callable[[str, bool, bool], str]
    record_file_metadata: Callable[..., Any]
    exec_prompt_enhancer_engine: Callable[..., Any]
    clear_queue_action: Callable[[Any], Any]


def _unload_prompt_enhancer_runtime(prompt_enhancer_image_caption_model, prompt_enhancer_llm_model) -> None:
    from shared.prompt_enhancer import unload_prompt_enhancer_models

    unload_prompt_enhancer_models(prompt_enhancer_image_caption_model, prompt_enhancer_llm_model)


class DeepyController:
    """Gradio-facing controller for the Deepy assistant, driven through DeepyDeps hooks."""

    def __init__(self, deps: DeepyDeps):
        self._deps = deps

    def get_verbose_level(self) -> int:
        try:
            return int(self._deps.get_verbose_level() or 0)
        except Exception:
            return 0

    def _sync_debug_enabled(self) -> bool:
        try:
            debug_enabled = int(self._deps.get_verbose_level() or 0) >= 2
        except Exception:
            debug_enabled = False
        set_assistant_debug(debug_enabled)
        return debug_enabled

    def _server_config(self) -> dict[str, Any]:
        return self._deps.get_server_config() or {}

    def is_available(self) -> bool:
        return deepy_available(self._server_config())

    def requirement_error_text(self) -> str:
        server_config = self._server_config()
        if not deepy_requirement_met(server_config):
            return _DEEPY_REQUIREMENT_TEXT
        if not normalize_deepy_enabled(server_config.get(DEEPY_ENABLED_KEY, 0)):
            return _DEEPY_DISABLED_TEXT
        return ""

    def get_vram_mode(self) -> str:
        server_config = self._server_config()
        return normalize_deepy_vram_mode(server_config.get(DEEPY_VRAM_MODE_KEY, DEEPY_VRAM_MODE_UNLOAD))

    def _ensure_vision_loaded(self, override_profile=None):
        self._deps.ensure_prompt_enhancer_loaded(override_profile=override_profile)
        image_caption_model = self._deps.get_image_caption_model()
        image_caption_processor = self._deps.get_image_caption_processor()
        if image_caption_model is None or image_caption_processor is None:
            raise gr.Error("Prompt enhancer vision runtime is not available.")
        return image_caption_model, image_caption_processor

    def _unload_weights(self) -> None:
        enhancer_offloadobj = self._deps.get_enhancer_offloadobj()
        if enhancer_offloadobj is not None:
            enhancer_offloadobj.unload_all()

    def _build_preload_release_callback(self) -> Callable[[], None]:
        def _release_preloaded_runtime() -> None:
            try:
                self._deps.unload_prompt_enhancer_runtime()
            finally:
                self._unload_weights()

        return _release_preloaded_runtime

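    # Fires the session's pending release-VRAM callback exactly once (it is cleared
    # before invocation), then optionally drops the runtime snapshot and, with
    # clear_session_state=True, the whole assistant session.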
    def release_vram(self, state, clear_session_state=False, discard_runtime_snapshot=False):
        session = get_or_create_assistant_session(state)
        release_callback = session.release_vram_callback
        session.release_vram_callback = None
        session.discard_runtime_snapshot_on_release = bool(discard_runtime_snapshot)
        self._deps.clear_gpu_resident(state)
        try:
            if callable(release_callback):
                release_callback()
        finally:
            if discard_runtime_snapshot:
                session.runtime_snapshot = None
                if len(session.rendered_token_ids) == 0:
                    session.pending_replay_reason = ""
            session.discard_runtime_snapshot_on_release = False
            if clear_session_state:
                clear_assistant_session(session)

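    # CLI pre-warm path: loads the prompt enhancer and, when the NanoVLLM backend is
    # in use, forces its runtime fully ready with a dummy one-token reservation. On
    # success the GPU is released with keep_resident=True so the next turn starts warm.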
    def preload_cli_runtime(self, state, override_profile=None) -> dict[str, Any]:
        self._sync_debug_enabled()
        self._deps.clear_gpu_resident(state)
        self._deps.acquire_gpu(state)
        keep_resident = False
        warmed_vllm = False
        try:
            model, _tokenizer = self._deps.ensure_prompt_enhancer_loaded(override_profile=override_profile)
            from shared.prompt_enhancer import qwen35_text

            if qwen35_text._use_vllm_prompt_enhancer(model):
                engine = qwen35_text._get_or_create_vllm_engine(model, usage_mode="assistant")
                engine.reserve_runtime(prompt_len=64, max_tokens=1, cfg_scale=1.0)
                engine._ensure_llm()
                llm = getattr(engine, "_llm", None)
                if llm is None:
                    raise RuntimeError("Assistant NanoVLLM runtime is not available.")
                llm.model_runner.ensure_runtime_ready()
                engine.release_runtime_allocations()
                warmed_vllm = True
            keep_resident = True
            return {"status": "ready", "warmed_vllm": warmed_vllm}
        finally:
            # On failure keep_resident is still False, so the GPU is fully released here.
            self._deps.release_gpu(
                state,
                keep_resident=keep_resident,
                release_vram_callback=self._build_preload_release_callback() if keep_resident else None,
                force_release_on_acquire=True,
            )

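    # Applies normalized tool UI settings to the session; with persist=True the same
    # values are also stored into the server config and written back to its JSON file.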
    def update_tool_ui_settings(
        self,
        state,
        *,
        auto_cancel_queue_tasks=None,
        use_template_properties=None,
        width=None,
        height=None,
        num_frames=None,
        seed=None,
        video_with_speech_variant=None,
        image_generator_variant=None,
        image_editor_variant=None,
        video_generator_variant=None,
        speech_from_description_variant=None,
        speech_from_sample_variant=None,
        persist=False,
    ):
        session = get_or_create_assistant_session(state)
        normalized = set_assistant_tool_ui_settings(
            session,
            auto_cancel_queue_tasks=auto_cancel_queue_tasks,
            use_template_properties=use_template_properties,
            width=width,
            height=height,
            num_frames=num_frames,
            seed=seed,
            video_with_speech_variant=video_with_speech_variant,
            image_generator_variant=image_generator_variant,
            image_editor_variant=image_editor_variant,
            video_generator_variant=video_generator_variant,
            speech_from_description_variant=speech_from_description_variant,
            speech_from_sample_variant=speech_from_sample_variant,
        )
        if persist:
            server_config = self._server_config()
            server_config_filename = str(self._deps.get_server_config_filename() or "").strip()
            if deepy_ui_settings.store_assistant_tool_ui_settings(server_config, normalized):
                set_deepy_runtime_config(server_config, server_config_filename)
                if len(server_config_filename) > 0:
                    with open(server_config_filename, "w", encoding="utf-8") as writer:
                        writer.write(json.dumps(server_config, indent=4))
        return normalized

    def persist_auto_cancel_queue_tasks(self, state, auto_cancel_queue_tasks):
        session = get_or_create_assistant_session(state)
        current = dict(session.tool_ui_settings or deepy_ui_settings.normalize_assistant_tool_ui_settings())
        current["auto_cancel_queue_tasks"] = auto_cancel_queue_tasks
        normalized = deepy_ui_settings.normalize_assistant_tool_ui_settings(**current)
        session.tool_ui_settings = dict(normalized)
        server_config = self._server_config()
        server_config_filename = str(self._deps.get_server_config_filename() or "").strip()
        if deepy_ui_settings.store_assistant_tool_ui_settings(server_config, normalized):
            set_deepy_runtime_config(server_config, server_config_filename)
            if len(server_config_filename) > 0:
                with open(server_config_filename, "w", encoding="utf-8") as writer:
                    writer.write(json.dumps(server_config, indent=4))
        return normalized["auto_cancel_queue_tasks"]

    def store_selected_video_time(self, state, current_time):
        gen = self._deps.get_gen_info(state)
        try:
            value = float(current_time)
        except Exception:
            value = None
        gen["selected_video_time"] = None if value is None or value < 0 else value

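    # Builds the per-turn tool registry handed to the assistant engine; the tools
    # share the generation info dict and push UI updates through send_cmd.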
    def create_tools(self, state, send_cmd, session=None):
        active_session = get_or_create_assistant_session(state) if session is None else session
        gen = self._deps.get_gen_info(state)
        return AssistantTools(
            gen,
            self._deps.get_processed_queue,
            send_cmd,
            session=active_session,
            get_output_filepath=self._deps.get_output_filepath,
            record_file_metadata=self._deps.record_file_metadata,
            get_server_config=self._server_config,
        )

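    # Synchronous single-turn entry point: checks availability, pulls sampling
    # settings from the server config, wires GPU/runtime hooks into an
    # AssistantEngine, and runs the turn on the first prompt.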
    def run_assistant_prompt_turn(
        self,
        state,
        model_def,
        prompt_enhancer_modes,
        original_prompts,
        seed,
        override_profile=None,
        send_cmd=None,
        tools=None,
    ) -> None:
        debug_enabled = self._sync_debug_enabled()
        server_config = self._server_config()
        if not normalize_deepy_enabled(server_config.get(DEEPY_ENABLED_KEY, 0)):
            raise gr.Error(_DEEPY_DISABLED_TEXT)
        if not deepy_requirement_met(server_config):
            raise gr.Error(_DEEPY_REQUIREMENT_TEXT)
        if send_cmd is None or tools is None:
            raise gr.Error("Assistant mode requires a command stream and a tool registry.")
        enhancer_temperature = server_config.get("prompt_enhancer_temperature", 0.6)
        enhancer_top_p = server_config.get("prompt_enhancer_top_p", 0.9)
        randomize_seed = server_config.get("prompt_enhancer_randomize_seed", True)
        assistant_seed = secrets.randbits(32) if randomize_seed else (seed if seed is not None and seed >= 0 else 0)
        session = get_or_create_assistant_session(state)
        assistant_model_def = model_def
        _assistant_instructions, assistant_max_new_tokens = self._deps.resolve_prompt_enhancer_settings(
            assistant_model_def, prompt_enhancer_modes, is_image=False, text_encoder_max_tokens=1024
        )
        assistant = AssistantEngine(
            session,
            AssistantRuntimeHooks(
                acquire_gpu=lambda: self._deps.acquire_gpu(state),
                release_gpu=lambda keep_resident=False, release_vram_callback=None, force_release_on_acquire=True: self._deps.release_gpu(
                    state,
                    keep_resident=keep_resident,
                    release_vram_callback=release_vram_callback,
                    force_release_on_acquire=force_release_on_acquire,
                ),
                register_gpu_resident=lambda release_vram_callback=None, force_release_on_acquire=True: self._deps.register_gpu_resident(
                    state,
                    release_vram_callback=release_vram_callback,
                    force_release_on_acquire=force_release_on_acquire,
                ),
                clear_gpu_resident=lambda: self._deps.clear_gpu_resident(state),
                ensure_loaded=lambda: self._deps.ensure_prompt_enhancer_loaded(override_profile=override_profile),
                unload_runtime=self._deps.unload_prompt_enhancer_runtime,
                unload_weights=self._unload_weights,
                ensure_vision_loaded=lambda: self._ensure_vision_loaded(override_profile=override_profile),
            ),
            tools,
            send_cmd,
            debug_enabled=debug_enabled,
            thinking_enabled="K" in prompt_enhancer_modes,  # the "K" mode flag toggles thinking
            vram_mode=self.get_vram_mode(),
        )
        assistant.run_turn(
            original_prompts[0] if len(original_prompts) > 0 else "",
            max_new_tokens=max(1024, int(assistant_max_new_tokens)),
            seed=assistant_seed,
            do_sample=True,
            temperature=enhancer_temperature,
            top_p=enhancer_top_p,
        )

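    # Gradio generator for a chat submission. Each yield is a 5-tuple of updates:
    # (chat event, queue-refresh trigger, chat input value, gallery-refresh trigger,
    # abort client id). The turn runs on a worker thread and streams commands back
    # through an AsyncStream until it pushes "exit".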
    def ask_ai(self, state, ask_request):
        self._sync_debug_enabled()

        def get_refresh_id():
            return str(time.time()) + "_" + str(self._deps.get_new_refresh_id())

        def drain_chat_output_batch(first_payload):
            # Coalesce consecutive chat_output events into a single UI update.
            payloads = [first_payload]
            while True:
                next_item = com_stream.output_queue.top()
                if not isinstance(next_item, tuple) or len(next_item) < 1 or next_item[0] != "chat_output":
                    break
                _cmd, next_payload = com_stream.output_queue.pop()
                payloads.append(next_payload)
            return assistant_chat.build_event_batch(payloads)

        session = get_or_create_assistant_session(state)
        ask_request = str(ask_request or "").strip()
        if len(ask_request) == 0:
            yield gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
            return
        if not self.is_available():
            error_turn_id = assistant_chat.create_assistant_turn(session)
            error_event = assistant_chat.set_assistant_content(session, error_turn_id, self.requirement_error_text())
            yield error_event if error_event is not None else gr.update(), gr.update(), gr.update(value=""), gr.update(), gr.update()
            return
        gen = self._deps.get_gen_info(state)
        com_stream = AsyncStream()
        send_cmd = com_stream.output_queue.push
        queued = session.worker_active or session.queued_job_count > 0
        queued_epoch = session.chat_epoch
        session.queued_job_count += 1
        user_message_id, _user_event = assistant_chat.add_user_message(session, ask_request, queued=queued)
        yield assistant_chat.build_sync_event(session), gr.update(), gr.update(value=""), gr.update(), gr.update()
        if queued:
            yield assistant_chat.build_status_event("Queued behind the current assistant task.", kind="queued"), gr.update(), gr.update(), gr.update(), gr.update()

        def queue_worker_func():
            session.queued_job_count = max(0, session.queued_job_count - 1)
            if queued_epoch != session.chat_epoch:
                # The chat was reset while this job sat in the queue; drop it.
                send_cmd("exit", None)
                return
            session.interrupt_requested = False
            session.control_queue = com_stream.output_queue
            session.worker_active = True
            begin_assistant_turn(session, user_message_id, ask_request)
            send_cmd("chat_output", assistant_chat.build_sync_event(session))
            queued_badge_event = assistant_chat.set_message_badge(session, user_message_id, None)
            if queued_badge_event is not None:
                send_cmd("chat_output", queued_badge_event)
            my_tools = self.create_tools(state, send_cmd, session=session)
            try:
                self._deps.exec_prompt_enhancer_engine(state, None, "AK", [ask_request], None, None, False, False, 0, None, 3.5, send_cmd, my_tools)
            except Exception as e:
                traceback.print_exc()
                error_turn_id = assistant_chat.create_assistant_turn(session)
                error_event = assistant_chat.set_assistant_content(session, error_turn_id, f"Assistant crashed: {e}")
                if error_event is not None:
                    send_cmd("chat_output", error_event)
                send_cmd("chat_output", assistant_chat.build_status_event(None, visible=False))
            finally:
                session.worker_active = False
                if session.control_queue is com_stream.output_queue:
                    session.control_queue = None
                if queued_epoch == session.chat_epoch:
                    send_cmd("chat_output", assistant_chat.build_sync_event(session))
                session.interrupt_requested = False
                send_cmd("exit", None)

        async_run_in("assistant", queue_worker_func)
        while True:
            cmd, data = com_stream.output_queue.next()
            if cmd == "console_output":
                print(data)
            elif cmd == "chat_output":
                yield drain_chat_output_batch(data), gr.update(), gr.update(), gr.update(), gr.update()
            elif cmd == "load_queue_trigger":
                yield gr.update(), str(get_refresh_id()), gr.update(), gr.update(), gr.update()
            elif cmd == "abort_client_id":
                yield gr.update(), gr.update(), gr.update(), gr.update(), str(data or "")
            elif cmd == "refresh_gallery":
                yield gr.update(), gr.update(), gr.update(), str(get_refresh_id()), gr.update()
            elif cmd == "error":
                error_turn_id = assistant_chat.create_assistant_turn(session)
                error_event = assistant_chat.set_assistant_content(session, error_turn_id, str(data or "Assistant error."))
                yield error_event if error_event is not None else gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
            elif cmd == "exit":
                break

    def stop_ai(self, state):
        session = get_or_create_assistant_session(state)
        if not session.worker_active:
            return gr.update(), gr.update(), gr.update(), gr.update()
        request_assistant_interrupt(session)
        return assistant_chat.build_status_event(None, visible=False), gr.update(), gr.update(), gr.update()

    def reset_ai(self, state):
        session = get_or_create_assistant_session(state)
        if session.worker_active:
            # A turn is running: ask the engine to reset itself and clear the chat log.
            request_assistant_reset(session)
            assistant_chat.reset_session_chat(session)
        else:
            # Idle: release VRAM and drop the session state entirely.
            self.release_vram(state, clear_session_state=True)
            session.chat_html = ""
        return assistant_chat.build_reset_event(), gr.update(), gr.update(value=""), gr.update()


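# Minimal wiring sketch (illustrative only; the real call site lives in the host
# app): every DeepyDeps field must be supplied as a keyword, e.g.
#
#   controller = create_controller(
#       get_server_config=lambda: server_config,
#       get_server_config_filename=lambda: config_path,
#       get_verbose_level=lambda: args.verbose,
#       ...,  # one keyword per remaining DeepyDeps field
#   )
#
# DeepyDeps is a slotted dataclass, so a missing or misspelled keyword fails fast
# with a TypeError at construction time.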
def create_controller(**deps_kwargs) -> DeepyController:
    return DeepyController(DeepyDeps(**deps_kwargs))