from __future__ import annotations

import copy
import gc
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

import gradio as gr
import torch

from shared.utils.hdr import tonemap_hdr_tensor_to_uint8
from shared.utils.virtual_media import build_virtual_media_path

from . import constants as ui_constants
from . import frame_planning as frames
from . import media_io as media
from . import prompt_schedule as prompts
from . import status_ui
from . import video_buffers as video
from .mux_session import MuxSession


# Settings keys removed from a user process before it is submitted as a chunk task
# (see build_task_settings).
USER_PROCESS_OUTPUT_KEYS = {
    "output_path",
    "output_dir",
    "save_path",
    "image_save_path",
    "audio_save_path",
}


def build_task_settings(process_settings: dict, *, is_user_process: bool) -> dict:
    """Return a deep copy of ``process_settings`` prepared for one chunk task.

    The input dict is never mutated. The copy gets ``repeat_generation`` and
    ``batch_size`` set to 1 and an ``_api`` dict whose ``return_media`` is True.
    For user processes the keys in ``USER_PROCESS_OUTPUT_KEYS`` are dropped and
    every ``_api`` entry other than ``return_media`` is discarded.

    Args:
        process_settings: Base settings of the process to run.
        is_user_process: Whether the settings come from a user-defined process.

    Returns:
        A new settings dict, independent of ``process_settings``.
    """
    settings = copy.deepcopy(process_settings)
    if is_user_process:
        for key in USER_PROCESS_OUTPUT_KEYS:
            settings.pop(key, None)
    # NOTE(review): the nesting of the next lines was reconstructed from a
    # whitespace-collapsed source; they are assumed to be unconditional — confirm.
    settings["repeat_generation"] = 1
    settings["batch_size"] = 1
    api_settings = settings.get("_api")
    # Shallow-copy an existing "_api" dict; anything that is not a dict becomes {}.
    settings["_api"] = dict(api_settings) if isinstance(api_settings, dict) else {}
    if is_user_process:
        settings["_api"] = {key: value for key, value in settings["_api"].items() if key == "return_media"}
    # ChunkExecutor.run reads video tensors from result.artifacts and raises when
    # none is returned, so returned media is requested for every task.
    settings["_api"]["return_media"] = True
    return settings


@dataclass(frozen=True)
class ProcessContext:
    """Frozen bundle of the per-run inputs that ``ChunkExecutor.run`` reads."""

    state: dict | None
    process_settings: dict
    model_type: str
    process_is_hdr: bool
    is_user_process: bool
    has_outpaint_setting: bool
    uses_builtin_outpaint_ui: bool
    use_lora_strength_override: bool
    active_process_strength: float
    active_target_ratio: str
    source_path: str
    output_path: str
    selected_audio_track: int | None
    prompt_schedule: list[tuple[float, str]]
    default_prompt_text: str
    budget_resolution: str
    start_frame: int
    resumed_unique_frames: int
    requested_unique_frames: int
    overlap_frames: int
    processing_fps: float
    fps_float: float
    continued_mode: bool
    plans: list[frames.ChunkPlan]
    total_chunks_display: int
    ffmpeg_path: str
    use_live_av_mux: bool
    output_container: str
    exact_start_seconds: float
    # Called by run() as timing_kwargs(completed_chunks[, current_step, total_steps]);
    # its result is unpacked as keyword arguments into render_chunk_status_html.
    timing_kwargs: Callable
    # When not None, run() takes the handler-driven branch for every chunk.
    system_handler: Any = None
    system_target_control: str = ""


@dataclass
class ChunkProgress:
    """Mutable progress record that ``ChunkExecutor.run`` updates chunk by chunk."""

    completed_chunks: int
    current_chunk_display: int
    chunk_output_paths: list[str]
    last_segment_path: str | None
    write_state: MuxSession
    written_unique_frames: int = 0
    # Empty until the first chunk fixes the output resolution ("<width>x<height>").
    resolved_resolution: str = ""
    resolved_width: int = 0
    resolved_height: int = 0
    # Only assigned by the system-handler branch of run().
    continue_cache: Any = None


@dataclass(frozen=True)
class ChunkExecutionResult:
    """Frozen snapshot of the ``ChunkProgress`` fields returned when ``run`` ends."""

    written_unique_frames: int
    completed_chunks: int
    current_chunk_display: int
    resolved_resolution: str
    resolved_width: int
    resolved_height: int
    last_segment_path: str | None
    chunk_output_paths: list[str]
    continue_cache: Any = None


class ChunkExecutor:
    """Submits one generation task per planned chunk and writes the returned frames."""

    def __init__(self, *, plugin, api_session, active_job: dict, preview_state: dict, ui_update, ui_skip, reset_live_chunk_status) -> None:
        """Store the collaborators used by ``run``; no other work is done here.

        Args:
            plugin: Object whose ``server_config`` is passed to the mux session.
            api_session: Object whose ``submit_task`` starts a chunk job.
            active_job: Shared dict; ``run`` reads "cancel_requested" and
                reads/writes "job".
            preview_state: Shared dict; ``run`` writes "image" after each chunk.
            ui_update: Callable whose return values are yielded by ``run``.
            ui_skip: Value passed through to ``ui_update`` after each chunk.
            reset_live_chunk_status: Callable invoked with ``context.state``
                before each task submission.
        """
        self.plugin = plugin
        self.api_session = api_session
        self.active_job = active_job
        self.preview_state = preview_state
        self.ui_update = ui_update
        self.ui_skip = ui_skip
        self.reset_live_chunk_status = reset_live_chunk_status

    def run(self, context: ProcessContext, progress: ChunkProgress):
        """Process every plan in ``context.plans`` as a generator of UI updates.

        For each chunk a task is submitted, polled until done, validated, and its
        frames are handed to ``progress.write_state``. ``progress`` is mutated in
        place. A cancellation sets ``progress.write_state.stopped`` and leaves the
        loop early; the final result is still returned in that case.

        NOTE(review): this file's indentation was reconstructed from a
        whitespace-collapsed source; block nesting marked below should be
        confirmed against the original layout.

        Args:
            context: Immutable inputs for the whole run.
            progress: Mutable progress record, updated as chunks complete.

        Yields:
            Values produced by ``self.ui_update(...)``.

        Returns:
            ChunkExecutionResult built from ``progress`` (delivered as the
            generator's return value).

        Raises:
            gr.Error: When a chunk fails, returns no usable tensor, changes the
                output resolution, or does not match the planned frame counts.
        """
        # chunk_index is 1-based, so context.plans[chunk_index] is the NEXT plan.
        for chunk_index, plan in enumerate(context.plans, start=1):
            callbacks = status_ui.ChunkCallbacks()
            last_html = ""
            # Unique frames already written before this chunk starts (resumed + this run).
            actual_done = context.resumed_unique_frames + progress.written_unique_frames
            plan_overlap_frames = plan.overlap_frames
            plan_requested_frames = plan.requested_frames
            actual_control_start_frame = plan.control_start_frame
            actual_control_end_frame = actual_control_start_frame + plan_requested_frames - 1
            overlap_buffer_start_frame = actual_control_start_frame
            # With overlap, the requested length is the plan length minus the overlap plus one.
            model_video_length = plan_requested_frames if plan_overlap_frames <= 0 else plan_requested_frames - plan_overlap_frames + 1
            needs_video_source = context.continued_mode or plan_overlap_frames > 0
            print(
                f"[Process Full Video] Chunk {chunk_index}: control video {frames.describe_frame_range(actual_control_start_frame, plan_requested_frames)}; "
                + (f"overlap buffer {frames.describe_frame_range(overlap_buffer_start_frame, plan_overlap_frames)}" if needs_video_source else "overlap buffer not used")
            )
            # ---- Branch 1: a system handler builds the task settings itself. ----
            if context.system_handler is not None:
                # Honor a pending cancel before submitting anything for this chunk.
                if self.active_job.get("cancel_requested"):
                    progress.write_state.stopped = True
                    break
                settings = context.system_handler.build_queue_settings(context.process_settings, source_path=context.source_path, start_frame=actual_control_start_frame, frame_count=plan_requested_frames, target_control=context.system_target_control, seed=chunk_index, continue_cache=progress.continue_cache, audio_track_no=context.selected_audio_track)
                self.reset_live_chunk_status(context.state)
                job = self.api_session.submit_task(settings, callbacks=callbacks)
                self.active_job["job"] = job
                yield self.ui_update(start_enabled=False, abort_enabled=True)
                # 0.0 guarantees the first loop iteration renders a status.
                next_status_refresh_at = 0.0
                stop_requested = False
                # Poll until the job is done, sleeping 0.1 s per iteration.
                while not job.done:
                    if self.active_job.get("cancel_requested") and not stop_requested:
                        try:
                            job.cancel()
                        except RuntimeError as exc:
                            print(f"[Process Full Video] Stop requested; WanGP abort bridge was not available: {exc}")
                        # NOTE(review): assumed to run whether or not cancel() raised — confirm nesting.
                        progress.write_state.stopped = True
                        # Ensures cancel() is attempted only once per job.
                        stop_requested = True
                    now = time.monotonic()
                    if now >= next_status_refresh_at:
                        html_value = status_ui.render_chunk_status_html(
                            context.total_chunks_display,
                            progress.completed_chunks,
                            progress.current_chunk_display,
                            callbacks.phase_label,
                            callbacks.status_text,
                            continued=context.continued_mode,
                            phase_current_step=callbacks.current_step,
                            phase_total_steps=callbacks.total_steps,
                            prefer_status_phase=True,
                            **context.timing_kwargs(progress.completed_chunks, callbacks.current_step, callbacks.total_steps),
                        )
                        next_status_refresh_at = now + ui_constants.STATUS_REFRESH_INTERVAL_SECONDS
                        # Only push an update when the rendered HTML actually changed.
                        if html_value != last_html:
                            last_html = html_value
                            yield self.ui_update(html_value)
                    time.sleep(0.1)
                try:
                    result = job.result()
                finally:
                    # Clear the shared job slot only if it still refers to this job.
                    if self.active_job.get("job") is job:
                        self.active_job["job"] = None
                # NOTE(review): assumed to follow the try/finally rather than sit inside it — confirm.
                yield self.ui_update(start_enabled=False, abort_enabled=False)
                if not result.success:
                    if result.cancelled:
                        progress.write_state.stopped = True
                        break
                    errors = list(result.errors or [])
                    raise gr.Error(str(errors[0] if errors else f"Chunk {chunk_index} failed."))
                # The job finished successfully but a cancel arrived meanwhile: stop without writing.
                if self.active_job.get("cancel_requested"):
                    progress.write_state.stopped = True
                    break
                # Record resolved paths of non-blank string entries not yet tracked.
                progress.chunk_output_paths.extend(
                    str(Path(path).resolve())
                    for path in result.generated_files
                    if isinstance(path, str) and len(path.strip()) > 0 and str(Path(path).resolve()) not in progress.chunk_output_paths
                )
                progress.last_segment_path = media.get_last_generated_video_path(list(result.generated_files)) or progress.last_segment_path
                # First artifact that carries a uint8 video tensor, if any.
                returned_video_item = next((item for item in result.artifacts if item.video_tensor_uint8 is not None), None)
                video_tensor_uint8 = None if returned_video_item is None else returned_video_item.video_tensor_uint8
                if not torch.is_tensor(video_tensor_uint8):
                    raise gr.Error(f"Chunk {chunk_index} completed without returned video tensor data.")
                video_tensor_uint8 = video_tensor_uint8.detach().cpu().contiguous()
                # Drop references held by the submitted settings/job before collecting garbage.
                api_settings = settings.get("_api")
                if isinstance(api_settings, dict):
                    api_settings["flashvsr_continue_cache"] = None
                release_input_payload = getattr(job, "release_input_payload", None)
                if callable(release_input_payload):
                    release_input_payload()
                # Carried into the next chunk's build_queue_settings call.
                progress.continue_cache = getattr(returned_video_item, "flashvsr_continue_cache", None)
                settings = None
                gc.collect()
                # Frame count is read from dimension 1 of the returned tensor.
                returned_frame_count = int(video_tensor_uint8.shape[1])
                print(f"[Process Full Video] Chunk {chunk_index}: returned video tensor has {returned_frame_count} frame(s); control video lasts {plan_requested_frames} frame(s)")
                chunk_width, chunk_height = video.get_video_tensor_resolution(video_tensor_uint8)
                chunk_resolution = f"{chunk_width}x{chunk_height}"
                print(f"[Process Full Video] Chunk {chunk_index}: generated chunk resolution {chunk_resolution}")
                # The first chunk fixes the output resolution; later chunks must match it.
                if len(progress.resolved_resolution) == 0:
                    progress.resolved_resolution = chunk_resolution
                    progress.resolved_width = chunk_width
                    progress.resolved_height = chunk_height
                elif chunk_resolution != progress.resolved_resolution:
                    raise gr.Error(f"Chunk {chunk_index} changed output resolution from {progress.resolved_resolution} to {chunk_resolution}.")
                remaining_unique_frames = context.requested_unique_frames - (context.resumed_unique_frames + progress.written_unique_frames)
                # Overlap of the following plan (0 for the last chunk); those trailing frames are not written now.
                next_overlap_frames = context.plans[chunk_index].overlap_frames if chunk_index < len(context.plans) else 0
                # Only the first chunk of a resumed run skips its leading overlap frames.
                leading_overlap_already_written = chunk_index == 1 and context.resumed_unique_frames > 0
                write_start = plan_overlap_frames if leading_overlap_already_written else 0
                write_end = plan_requested_frames - next_overlap_frames
                frames_to_write = write_end - write_start
                if frames_to_write <= 0:
                    raise gr.Error(f"Chunk {chunk_index} has no new frame to write after keeping {next_overlap_frames} lookahead frame(s).")
                if frames_to_write > remaining_unique_frames:
                    raise gr.Error(f"Chunk {chunk_index} would write {frames_to_write} frame(s), but only {remaining_unique_frames} frame(s) remain.")
                if returned_frame_count < write_end:
                    raise gr.Error(f"Chunk {chunk_index} returned {returned_frame_count} frame(s), but {write_end} frame(s) were required.")
                # Planned unique frames divided by fps; only computed for live A/V muxing.
                source_audio_duration_seconds = float(frames.count_planned_unique_frames(context.plans)) / float(context.fps_float) if context.use_live_av_mux else None
                # This branch always writes as non-HDR (process_is_hdr=False).
                progress.write_state.ensure_started(
                    server_config=self.plugin.server_config,
                    ffmpeg_path=context.ffmpeg_path,
                    process_is_hdr=False,
                    use_live_av_mux=context.use_live_av_mux,
                    output_container=context.output_container,
                    source_path=context.source_path,
                    exact_start_seconds=context.exact_start_seconds,
                    selected_audio_track=context.selected_audio_track,
                    resolved_width=progress.resolved_width,
                    resolved_height=progress.resolved_height,
                    fps_float=context.fps_float,
                    source_audio_duration_seconds=source_audio_duration_seconds,
                )
                # In continued mode, if the session writes to a different path than the
                # requested output, let the handler move its cache when it supports that.
                if context.continued_mode and progress.write_state.output_path_for_write != context.output_path and callable(getattr(context.system_handler, "move_continue_cache", None)):
                    context.system_handler.move_continue_cache(context.output_path, progress.write_state.output_path_for_write)
                last_frame_tensor = progress.write_state.write_chunk(process_is_hdr=False, video_tensor_hdr=None, video_tensor_uint8=video_tensor_uint8, start_frame=write_start, frame_count=frames_to_write)
                progress.written_unique_frames += frames_to_write
                self.preview_state["image"] = video.frame_to_image(last_frame_tensor)
                if progress.continue_cache is not None and hasattr(context.system_handler, "save_continue_cache"):
                    context.system_handler.save_continue_cache(progress.continue_cache, progress.write_state.output_path_for_write, metadata={"written_unique_frames": int(context.resumed_unique_frames + progress.written_unique_frames), "chunk": int(chunk_index)})
                # Release the chunk tensor reference before the next iteration.
                video_tensor_uint8 = None
                progress.completed_chunks += 1
                # The third ui_update argument is a fresh nanosecond timestamp string
                # (presumably a change token for the UI — TODO confirm).
                if chunk_index < len(context.plans):
                    progress.current_chunk_display = progress.completed_chunks + 1
                    yield self.ui_update(status_ui.render_chunk_status_html(context.total_chunks_display, progress.completed_chunks, progress.current_chunk_display, "Starting new Chunk", f"Chunk {progress.completed_chunks} finished with {frames_to_write} written frame(s). Preparing next chunk...", continued=context.continued_mode, **context.timing_kwargs(progress.completed_chunks)), self.ui_skip, str(time.time_ns()))
                else:
                    progress.current_chunk_display = progress.completed_chunks
                    yield self.ui_update(status_ui.render_chunk_status_html(context.total_chunks_display, progress.completed_chunks, progress.current_chunk_display, "Chunk Completed", f"Chunk {progress.completed_chunks} finished with {frames_to_write} written frame(s).", continued=context.continued_mode, **context.timing_kwargs(progress.completed_chunks)), self.ui_skip, str(time.time_ns()))
                # Skip the default branch below for handler-driven chunks.
                continue
            # ---- Branch 2: default path, settings are derived from the process settings. ----
            settings = build_task_settings(context.process_settings, is_user_process=context.is_user_process)
            # The prompt is chosen by the time offset of the first not-yet-written frame.
            chunk_prompt_start_seconds = float(actual_done) / float(context.fps_float)
            settings["model_type"] = context.model_type
            settings["prompt"] = prompts.resolve_prompt_for_chunk(context.prompt_schedule, chunk_prompt_start_seconds, context.default_prompt_text)
            # Once a resolution has been resolved it is reused; otherwise the budget resolution applies.
            settings["resolution"] = progress.resolved_resolution or context.budget_resolution
            settings["video_length"] = model_video_length
            settings["sliding_window_overlap"] = plan_overlap_frames if plan_overlap_frames > 0 else 1
            settings["image_prompt_type"] = "V" if needs_video_source else ""
            settings["audio_prompt_type"] = "K"
            if context.is_user_process:
                settings["force_fps"] = "control"
            # NOTE(review): assumed unconditional (not part of the is_user_process block) — confirm.
            settings["video_guide"] = build_virtual_media_path(context.source_path, start_frame=actual_control_start_frame, end_frame=actual_control_end_frame, audio_track_no=context.selected_audio_track)
            if context.uses_builtin_outpaint_ui:
                settings["video_guide_outpainting_ratio"] = context.active_target_ratio
            elif not context.has_outpaint_setting:
                settings.pop("video_guide_outpainting_ratio", None)
            if context.use_lora_strength_override:
                settings["loras_multipliers"] = str(context.active_process_strength)
            if needs_video_source:
                settings["video_source"] = video.build_process_full_video_source_path(hdr=context.process_is_hdr)
            else:
                settings.pop("video_source", None)
            self.reset_live_chunk_status(context.state)
            job = self.api_session.submit_task(settings, callbacks=callbacks)
            self.active_job["job"] = job
            yield self.ui_update(start_enabled=False, abort_enabled=True)
            # 0.0 guarantees the first loop iteration renders a status.
            next_status_refresh_at = 0.0
            stop_requested = False
            # Poll until the job is done, sleeping 0.1 s per iteration (same loop as branch 1).
            while not job.done:
                if self.active_job.get("cancel_requested") and not stop_requested:
                    try:
                        job.cancel()
                    except RuntimeError as exc:
                        print(f"[Process Full Video] Stop requested; WanGP abort bridge was not available: {exc}")
                    # NOTE(review): assumed to run whether or not cancel() raised — confirm nesting.
                    progress.write_state.stopped = True
                    # Ensures cancel() is attempted only once per job.
                    stop_requested = True
                now = time.monotonic()
                if now >= next_status_refresh_at:
                    html_value = status_ui.render_chunk_status_html(
                        context.total_chunks_display,
                        progress.completed_chunks,
                        progress.current_chunk_display,
                        callbacks.phase_label,
                        callbacks.status_text,
                        continued=context.continued_mode,
                        phase_current_step=callbacks.current_step,
                        phase_total_steps=callbacks.total_steps,
                        prefer_status_phase=True,
                        **context.timing_kwargs(progress.completed_chunks, callbacks.current_step, callbacks.total_steps),
                    )
                    next_status_refresh_at = now + ui_constants.STATUS_REFRESH_INTERVAL_SECONDS
                    # Only push an update when the rendered HTML actually changed.
                    if html_value != last_html:
                        last_html = html_value
                        yield self.ui_update(html_value)
                time.sleep(0.1)
            try:
                result = job.result()
            finally:
                # Clear the shared job slot only if it still refers to this job.
                if self.active_job.get("job") is job:
                    self.active_job["job"] = None
            # NOTE(review): assumed to follow the try/finally rather than sit inside it — confirm.
            yield self.ui_update(start_enabled=False, abort_enabled=False)
            if not result.success:
                if result.cancelled:
                    progress.write_state.stopped = True
                    break
                errors = list(result.errors or [])
                raise gr.Error(str(errors[0] if errors else f"Chunk {chunk_index} failed."))
            # Record resolved paths of non-blank string entries not yet tracked.
            progress.chunk_output_paths.extend(
                str(Path(path).resolve())
                for path in result.generated_files
                if isinstance(path, str) and len(path.strip()) > 0 and str(Path(path).resolve()) not in progress.chunk_output_paths
            )
            progress.last_segment_path = media.get_last_generated_video_path(list(result.generated_files)) or progress.last_segment_path
            # Pick the first artifact carrying the tensor kind that matches the HDR mode.
            returned_video_item = next((item for item in result.artifacts if item.video_tensor_hdr is not None), None) if context.process_is_hdr else next((item for item in result.artifacts if item.video_tensor_uint8 is not None), None)
            returned_tensor = None if returned_video_item is None else (returned_video_item.video_tensor_hdr if context.process_is_hdr else returned_video_item.video_tensor_uint8)
            if returned_video_item is None or not torch.is_tensor(returned_tensor):
                raise gr.Error(f"Chunk {chunk_index} completed without returned video tensor data.")
            # In HDR mode the uint8 tensor is derived from the HDR tensor by tonemapping.
            video_tensor_hdr = returned_tensor.detach().cpu() if context.process_is_hdr else None
            video_tensor_uint8 = tonemap_hdr_tensor_to_uint8(video_tensor_hdr) if context.process_is_hdr else returned_tensor.detach().cpu()
            # Frame count is read from dimension 1 of the tensor.
            returned_frame_count = int(video_tensor_uint8.shape[1])
            expected_frame_count = plan_requested_frames
            # One missing frame is tolerated before the file-decoding fallback kicks in.
            minimum_returned_frames = expected_frame_count - 1 if expected_frame_count > 1 else 1
            # Non-HDR fallback: if the tensor is too short, decode the first existing
            # generated video file and use it when it has enough frames.
            if not context.process_is_hdr and returned_frame_count < minimum_returned_frames:
                video_candidates = [path for path in result.generated_files if isinstance(path, str) and os.path.isfile(path) and str(Path(path).suffix).lower() in {".mp4", ".mkv", ".mov", ".avi"}]
                if video_candidates:
                    decoded_tensor = video.load_video_tensor_from_file(video_candidates[0])
                    decoded_frame_count = int(decoded_tensor.shape[1])
                    print(f"[Process Full Video] Chunk {chunk_index}: returned video tensor has {returned_frame_count} frame(s); decoded chunk file has {decoded_frame_count} frame(s)")
                    if decoded_frame_count >= minimum_returned_frames:
                        video_tensor_uint8 = decoded_tensor
                        returned_frame_count = decoded_frame_count
            print(f"[Process Full Video] Chunk {chunk_index}: returned video tensor has {returned_frame_count} frame(s); control video lasts {expected_frame_count} frame(s)")
            chunk_width, chunk_height = video.get_video_tensor_resolution(video_tensor_uint8)
            chunk_resolution = f"{chunk_width}x{chunk_height}"
            print(f"[Process Full Video] Chunk {chunk_index}: generated chunk resolution {chunk_resolution}")
            # The first chunk fixes the output resolution; later chunks must match it.
            if len(progress.resolved_resolution) == 0:
                progress.resolved_resolution = chunk_resolution
                progress.resolved_width = chunk_width
                progress.resolved_height = chunk_height
            elif chunk_resolution != progress.resolved_resolution:
                raise gr.Error(f"Chunk {chunk_index} changed output resolution from {progress.resolved_resolution} to {chunk_resolution}.")
            # In this branch the leading overlap frames of every chunk are skipped when writing.
            skip_frames = plan_overlap_frames
            remaining_unique_frames = context.requested_unique_frames - (context.resumed_unique_frames + progress.written_unique_frames)
            expected_unique_frames = plan_requested_frames - skip_frames
            if expected_unique_frames <= 0:
                raise gr.Error(f"Chunk {chunk_index} has no writable frame in the computed plan.")
            if expected_unique_frames > remaining_unique_frames:
                raise gr.Error(f"Chunk {chunk_index} would write {expected_unique_frames} frame(s), but only {remaining_unique_frames} frame(s) remain.")
            writable_frame_count = int(video_tensor_uint8.shape[1]) - skip_frames
            if writable_frame_count < expected_unique_frames:
                raise gr.Error(f"Chunk {chunk_index} returned {writable_frame_count} writable frame(s), but {expected_unique_frames} frame(s) were required.")
            frames_to_write = expected_unique_frames
            # Planned unique frames divided by fps; only computed for live A/V muxing.
            source_audio_duration_seconds = float(frames.count_planned_unique_frames(context.plans)) / float(context.fps_float) if context.use_live_av_mux else None
            progress.write_state.ensure_started(
                server_config=self.plugin.server_config,
                ffmpeg_path=context.ffmpeg_path,
                process_is_hdr=context.process_is_hdr,
                use_live_av_mux=context.use_live_av_mux,
                output_container=context.output_container,
                source_path=context.source_path,
                exact_start_seconds=context.exact_start_seconds,
                selected_audio_track=context.selected_audio_track,
                resolved_width=progress.resolved_width,
                resolved_height=progress.resolved_height,
                fps_float=context.fps_float,
                source_audio_duration_seconds=source_audio_duration_seconds,
            )
            last_frame_tensor = progress.write_state.write_chunk(process_is_hdr=context.process_is_hdr, video_tensor_hdr=video_tensor_hdr, video_tensor_uint8=video_tensor_uint8, start_frame=skip_frames, frame_count=frames_to_write)
            progress.written_unique_frames += frames_to_write
            self.preview_state["image"] = video.frame_to_image(last_frame_tensor)
            # Feed only the frames just written into the overlap buffer helper.
            overlap_source_tensor = video_tensor_hdr if context.process_is_hdr and video_tensor_hdr is not None else video_tensor_uint8
            next_overlap_tensor = video.update_process_full_video_overlap_buffer(overlap_source_tensor[:, skip_frames:skip_frames + frames_to_write], context.overlap_frames, context.processing_fps, hdr=context.process_is_hdr)
            # The returned tensor is used here only to log the next overlap range.
            if next_overlap_tensor is not None and int(next_overlap_tensor.shape[1]) > 0:
                next_overlap_count = int(next_overlap_tensor.shape[1])
                next_overlap_start_frame = context.start_frame + context.resumed_unique_frames + progress.written_unique_frames - next_overlap_count
                print(f"[Process Full Video] Chunk {chunk_index}: next overlap buffer {frames.describe_frame_range(next_overlap_start_frame, next_overlap_count)}")
            progress.completed_chunks += 1
            # Replace the tracked path list with the helper's result, asking it to preserve the last segment.
            progress.chunk_output_paths = media.delete_released_chunk_outputs(context.state, progress.chunk_output_paths, preserve_paths=[progress.last_segment_path] if progress.last_segment_path else None)
            # The third ui_update argument is a fresh nanosecond timestamp string
            # (presumably a change token for the UI — TODO confirm).
            if chunk_index < len(context.plans):
                progress.current_chunk_display = progress.completed_chunks + 1
                yield self.ui_update(status_ui.render_chunk_status_html(context.total_chunks_display, progress.completed_chunks, progress.current_chunk_display, "Starting new Chunk", f"Chunk {progress.completed_chunks} finished with {frames_to_write} written frame(s). Preparing next chunk...", continued=context.continued_mode, **context.timing_kwargs(progress.completed_chunks)), self.ui_skip, str(time.time_ns()))
            else:
                progress.current_chunk_display = progress.completed_chunks
                yield self.ui_update(status_ui.render_chunk_status_html(context.total_chunks_display, progress.completed_chunks, progress.current_chunk_display, "Chunk Completed", f"Chunk {progress.completed_chunks} finished with {frames_to_write} written frame(s).", continued=context.continued_mode, **context.timing_kwargs(progress.completed_chunks)), self.ui_skip, str(time.time_ns()))
        # Reached after the last chunk or after an early break (cancellation).
        return ChunkExecutionResult(
            written_unique_frames=progress.written_unique_frames,
            completed_chunks=progress.completed_chunks,
            current_chunk_display=progress.current_chunk_display,
            resolved_resolution=progress.resolved_resolution,
            resolved_width=progress.resolved_width,
            resolved_height=progress.resolved_height,
            last_segment_path=progress.last_segment_path,
            chunk_output_paths=progress.chunk_output_paths,
            continue_cache=progress.continue_cache,
        )