# Hugging Face file-page header (not part of the module):
#   1ripon1 - "Upload folder using huggingface_hub" - commit 7344bef (verified) - 27 kB
from __future__ import annotations
import copy
import gc
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable
import gradio as gr
import torch
from shared.utils.hdr import tonemap_hdr_tensor_to_uint8
from shared.utils.virtual_media import build_virtual_media_path
from . import constants as ui_constants
from . import frame_planning as frames
from . import media_io as media
from . import prompt_schedule as prompts
from . import status_ui
from . import video_buffers as video
from .mux_session import MuxSession
# Settings keys that build_task_settings removes from a user-defined process
# before the settings are submitted as a chunk task.
USER_PROCESS_OUTPUT_KEYS = {
    "audio_save_path",
    "image_save_path",
    "output_dir",
    "output_path",
    "save_path",
}
def build_task_settings(process_settings: dict, *, is_user_process: bool) -> dict:
    """Return an independent copy of *process_settings* prepared for one chunk task.

    The input mapping is deep-copied, so the caller's dict is never mutated.

    Args:
        process_settings: Base settings of the selected process.
        is_user_process: True when the settings come from a user-defined
            process. Such settings lose every key listed in
            ``USER_PROCESS_OUTPUT_KEYS`` and every ``_api`` option.

    Returns:
        A new settings dict whose ``"_api"`` entry is a plain dict with
        ``"return_media"`` set to True.
    """
    settings = copy.deepcopy(process_settings)
    if is_user_process:
        for key in USER_PROCESS_OUTPUT_KEYS:
            settings.pop(key, None)
        # NOTE(review): nesting was reconstructed from flattened source; these
        # two overrides are assumed to apply to user processes only - confirm.
        settings["repeat_generation"] = 1
        settings["batch_size"] = 1

    api_settings = settings.get("_api")
    if is_user_process or not isinstance(api_settings, dict):
        # A user process keeps none of its own "_api" options; a missing or
        # malformed "_api" value is replaced by an empty dict as well.
        api_settings = {}
    else:
        api_settings = dict(api_settings)
    # NOTE(review): assumed unconditional (the executor requires returned
    # tensors for every chunk) - confirm against the original layout.
    api_settings["return_media"] = True
    settings["_api"] = api_settings
    return settings
@dataclass(frozen=True)
class ProcessContext:
    """Read-only inputs for one ``ChunkExecutor.run`` call.

    The instance is frozen; ``run`` records its progress on a separate
    ``ChunkProgress`` object instead of modifying this one.
    """

    # Passed to reset_live_chunk_status and media.delete_released_chunk_outputs.
    state: dict | None
    # Base settings from which every chunk's task settings are derived.
    process_settings: dict
    # Written to settings["model_type"] for model-driven chunks.
    model_type: str
    # True when the HDR tensor (video_tensor_hdr) of each result is used.
    process_is_hdr: bool
    # Forwarded to build_task_settings; also forces settings["force_fps"] = "control".
    is_user_process: bool
    # Together these decide whether "video_guide_outpainting_ratio" is set,
    # kept, or removed from the chunk settings.
    has_outpaint_setting: bool
    uses_builtin_outpaint_ui: bool
    # When True, active_process_strength overrides settings["loras_multipliers"].
    use_lora_strength_override: bool
    active_process_strength: float
    active_target_ratio: str
    source_path: str
    output_path: str
    selected_audio_track: int | None
    # Passed to prompts.resolve_prompt_for_chunk together with default_prompt_text.
    prompt_schedule: list[tuple[float, str]]
    default_prompt_text: str
    # Resolution used until the first chunk fixes progress.resolved_resolution.
    budget_resolution: str
    start_frame: int
    # Unique frames already accounted for before this run starts.
    resumed_unique_frames: int
    requested_unique_frames: int
    overlap_frames: int
    processing_fps: float
    fps_float: float
    continued_mode: bool
    # One plan per chunk, processed in list order.
    plans: list[frames.ChunkPlan]
    total_chunks_display: int
    ffmpeg_path: str
    use_live_av_mux: bool
    output_container: str
    exact_start_seconds: float
    # Called as timing_kwargs(completed_chunks[, current_step, total_steps]);
    # the returned mapping is expanded into status_ui.render_chunk_status_html.
    timing_kwargs: Callable
    # When not None, chunk settings come from system_handler.build_queue_settings
    # instead of build_task_settings.
    system_handler: Any = None
    system_target_control: str = ""
@dataclass
class ChunkProgress:
    """Mutable progress state that ``ChunkExecutor.run`` updates in place."""

    # Number of chunks fully written so far.
    completed_chunks: int
    # Chunk number shown in the status HTML.
    current_chunk_display: int
    # Resolved paths of files reported by finished jobs.
    chunk_output_paths: list[str]
    last_segment_path: str | None
    # Output writer; its ``stopped`` flag is set when a stop is requested.
    write_state: MuxSession
    # Unique frames written during this run (resumed frames are not included).
    written_unique_frames: int = 0
    # "<width>x<height>" of the first chunk; later chunks must match it.
    resolved_resolution: str = ""
    resolved_width: int = 0
    resolved_height: int = 0
    # Opaque value carried from one system-handler chunk to the next.
    continue_cache: Any = None
@dataclass(frozen=True)
class ChunkExecutionResult:
    """Immutable copy of the ``ChunkProgress`` values returned by ``ChunkExecutor.run``."""

    written_unique_frames: int
    completed_chunks: int
    current_chunk_display: int
    resolved_resolution: str
    resolved_width: int
    resolved_height: int
    last_segment_path: str | None
    # Same list object as progress.chunk_output_paths at the time run returns.
    chunk_output_paths: list[str]
    continue_cache: Any = None
class ChunkExecutor:
    """Run the planned chunks of a full-video process, one submitted task per chunk.

    ``run`` is the entry point. It is a generator: every value it yields is
    built by the ``ui_update`` callable supplied at construction, and its
    return value is a ``ChunkExecutionResult``.
    """

    def __init__(self, *, plugin, api_session, active_job: dict, preview_state: dict, ui_update, ui_skip, reset_live_chunk_status) -> None:
        """Store the collaborators used while running chunks.

        Args:
            plugin: Object whose ``server_config`` is handed to the output writer.
            api_session: Object whose ``submit_task(settings, callbacks=...)`` starts a job.
            active_job: Shared dict; ``"job"`` holds the running job and
                ``"cancel_requested"`` is read to detect a stop request.
            preview_state: Dict whose ``"image"`` entry receives the last written frame.
            ui_update: Callable that builds the values yielded by ``run``.
            ui_skip: Value passed to ``ui_update`` with every chunk-finished update.
            reset_live_chunk_status: Callable invoked with ``context.state``
                before each task submission.
        """
        self.plugin = plugin
        self.api_session = api_session
        self.active_job = active_job
        self.preview_state = preview_state
        self.ui_update = ui_update
        self.ui_skip = ui_skip
        self.reset_live_chunk_status = reset_live_chunk_status

    def run(self, context: ProcessContext, progress: ChunkProgress):
        """Process every plan in ``context.plans`` in order, updating *progress* in place.

        Chunks are delegated to the system-handler path when
        ``context.system_handler`` is set, otherwise to the model path.
        Processing stops early when a job is cancelled or a stop was requested.

        Yields:
            Values built by ``self.ui_update`` (button state and status HTML).

        Returns:
            A ``ChunkExecutionResult`` copied from *progress*.

        Raises:
            gr.Error: When a job fails, returns no usable tensor, changes the
                output resolution, or does not match the frame plan.
        """
        for chunk_index, plan in enumerate(context.plans, start=1):
            callbacks = status_ui.ChunkCallbacks()
            needs_video_source = context.continued_mode or plan.overlap_frames > 0
            control_range = frames.describe_frame_range(plan.control_start_frame, plan.requested_frames)
            if needs_video_source:
                # The overlap buffer is described from the chunk's control start frame.
                overlap_text = f"overlap buffer {frames.describe_frame_range(plan.control_start_frame, plan.overlap_frames)}"
            else:
                overlap_text = "overlap buffer not used"
            print(f"[Process Full Video] Chunk {chunk_index}: control video {control_range}; " + overlap_text)

            if context.system_handler is not None:
                chunk_runner = self._run_system_chunk(context, progress, chunk_index, plan, callbacks)
            else:
                chunk_runner = self._run_model_chunk(context, progress, chunk_index, plan, callbacks, needs_video_source)
            keep_going = yield from chunk_runner
            if not keep_going:
                break

        return ChunkExecutionResult(
            written_unique_frames=progress.written_unique_frames,
            completed_chunks=progress.completed_chunks,
            current_chunk_display=progress.current_chunk_display,
            resolved_resolution=progress.resolved_resolution,
            resolved_width=progress.resolved_width,
            resolved_height=progress.resolved_height,
            last_segment_path=progress.last_segment_path,
            chunk_output_paths=progress.chunk_output_paths,
            continue_cache=progress.continue_cache,
        )

    def _wait_for_job(self, job, context: ProcessContext, progress: ChunkProgress, callbacks):
        """Poll *job* until it is done, yielding status updates; return ``job.result()``."""
        self.active_job["job"] = job
        yield self.ui_update(start_enabled=False, abort_enabled=True)
        last_html = ""
        next_status_refresh_at = 0.0
        stop_requested = False
        while not job.done:
            if self.active_job.get("cancel_requested") and not stop_requested:
                try:
                    job.cancel()
                except RuntimeError as exc:
                    print(f"[Process Full Video] Stop requested; WanGP abort bridge was not available: {exc}")
                # NOTE(review): nesting reconstructed from flattened source; the
                # writer is assumed to be flagged whether or not cancel() raised.
                progress.write_state.stopped = True
                # Cancel is requested from the job only once per chunk.
                stop_requested = True
            now = time.monotonic()
            if now >= next_status_refresh_at:
                html_value = status_ui.render_chunk_status_html(
                    context.total_chunks_display,
                    progress.completed_chunks,
                    progress.current_chunk_display,
                    callbacks.phase_label,
                    callbacks.status_text,
                    continued=context.continued_mode,
                    phase_current_step=callbacks.current_step,
                    phase_total_steps=callbacks.total_steps,
                    prefer_status_phase=True,
                    **context.timing_kwargs(progress.completed_chunks, callbacks.current_step, callbacks.total_steps),
                )
                next_status_refresh_at = now + ui_constants.STATUS_REFRESH_INTERVAL_SECONDS
                # NOTE(review): nesting reconstructed from flattened source; this
                # check is assumed to sit inside the refresh branch, which is
                # equivalent to running it on every loop pass.
                # Only push HTML that actually changed.
                if html_value != last_html:
                    last_html = html_value
                    yield self.ui_update(html_value)
            time.sleep(0.1)
        try:
            result = job.result()
        finally:
            # Release the shared slot unless another job has replaced this one.
            if self.active_job.get("job") is job:
                self.active_job["job"] = None
        # NOTE(review): assumed to follow the try/finally rather than sit inside
        # the finally clause - confirm against the original layout.
        yield self.ui_update(start_enabled=False, abort_enabled=False)
        return result

    @staticmethod
    def _job_succeeded(result, progress: ChunkProgress, chunk_index: int) -> bool:
        """Return True on success, False on cancellation; raise gr.Error on failure."""
        if result.success:
            return True
        if result.cancelled:
            progress.write_state.stopped = True
            return False
        errors = list(result.errors or [])
        raise gr.Error(str(errors[0] if errors else f"Chunk {chunk_index} failed."))

    @staticmethod
    def _record_generated_files(progress: ChunkProgress, generated_files) -> None:
        """Add new resolved file paths to *progress* and refresh ``last_segment_path``."""
        for path in generated_files:
            if not isinstance(path, str) or len(path.strip()) == 0:
                continue
            resolved_path = str(Path(path).resolve())
            if resolved_path not in progress.chunk_output_paths:
                progress.chunk_output_paths.append(resolved_path)
        progress.last_segment_path = media.get_last_generated_video_path(list(generated_files)) or progress.last_segment_path

    @staticmethod
    def _register_chunk_resolution(progress: ChunkProgress, chunk_index: int, video_tensor_uint8) -> None:
        """Record the first chunk's resolution; raise gr.Error if a later chunk differs."""
        chunk_width, chunk_height = video.get_video_tensor_resolution(video_tensor_uint8)
        chunk_resolution = f"{chunk_width}x{chunk_height}"
        print(f"[Process Full Video] Chunk {chunk_index}: generated chunk resolution {chunk_resolution}")
        if len(progress.resolved_resolution) == 0:
            progress.resolved_resolution = chunk_resolution
            progress.resolved_width = chunk_width
            progress.resolved_height = chunk_height
        elif chunk_resolution != progress.resolved_resolution:
            raise gr.Error(f"Chunk {chunk_index} changed output resolution from {progress.resolved_resolution} to {chunk_resolution}.")

    def _ensure_writer_started(self, context: ProcessContext, progress: ChunkProgress, *, process_is_hdr: bool) -> None:
        """Call ``ensure_started`` on the writer with the resolved output geometry."""
        if context.use_live_av_mux:
            source_audio_duration_seconds = float(frames.count_planned_unique_frames(context.plans)) / float(context.fps_float)
        else:
            source_audio_duration_seconds = None
        progress.write_state.ensure_started(
            server_config=self.plugin.server_config,
            ffmpeg_path=context.ffmpeg_path,
            process_is_hdr=process_is_hdr,
            use_live_av_mux=context.use_live_av_mux,
            output_container=context.output_container,
            source_path=context.source_path,
            exact_start_seconds=context.exact_start_seconds,
            selected_audio_track=context.selected_audio_track,
            resolved_width=progress.resolved_width,
            resolved_height=progress.resolved_height,
            fps_float=context.fps_float,
            source_audio_duration_seconds=source_audio_duration_seconds,
        )

    def _chunk_finished_update(self, context: ProcessContext, progress: ChunkProgress, chunk_index: int, frames_to_write: int):
        """Advance ``current_chunk_display`` and build the chunk-finished UI update."""
        if chunk_index < len(context.plans):
            progress.current_chunk_display = progress.completed_chunks + 1
            phase_label = "Starting new Chunk"
            status_text = f"Chunk {progress.completed_chunks} finished with {frames_to_write} written frame(s). Preparing next chunk..."
        else:
            progress.current_chunk_display = progress.completed_chunks
            phase_label = "Chunk Completed"
            status_text = f"Chunk {progress.completed_chunks} finished with {frames_to_write} written frame(s)."
        html_value = status_ui.render_chunk_status_html(
            context.total_chunks_display,
            progress.completed_chunks,
            progress.current_chunk_display,
            phase_label,
            status_text,
            continued=context.continued_mode,
            **context.timing_kwargs(progress.completed_chunks),
        )
        return self.ui_update(html_value, self.ui_skip, str(time.time_ns()))

    def _run_system_chunk(self, context: ProcessContext, progress: ChunkProgress, chunk_index: int, plan, callbacks):
        """Run one chunk through ``context.system_handler``; return False when the run must stop."""
        if self.active_job.get("cancel_requested"):
            progress.write_state.stopped = True
            return False

        plan_overlap_frames = plan.overlap_frames
        plan_requested_frames = plan.requested_frames
        settings = context.system_handler.build_queue_settings(
            context.process_settings,
            source_path=context.source_path,
            start_frame=plan.control_start_frame,
            frame_count=plan_requested_frames,
            target_control=context.system_target_control,
            seed=chunk_index,
            continue_cache=progress.continue_cache,
            audio_track_no=context.selected_audio_track,
        )
        self.reset_live_chunk_status(context.state)
        job = self.api_session.submit_task(settings, callbacks=callbacks)
        result = yield from self._wait_for_job(job, context, progress, callbacks)
        if not self._job_succeeded(result, progress, chunk_index):
            return False
        # A stop requested while the job was finishing discards this chunk.
        if self.active_job.get("cancel_requested"):
            progress.write_state.stopped = True
            return False

        self._record_generated_files(progress, result.generated_files)
        returned_video_item = next((item for item in result.artifacts if item.video_tensor_uint8 is not None), None)
        video_tensor_uint8 = None if returned_video_item is None else returned_video_item.video_tensor_uint8
        if not torch.is_tensor(video_tensor_uint8):
            raise gr.Error(f"Chunk {chunk_index} completed without returned video tensor data.")
        video_tensor_uint8 = video_tensor_uint8.detach().cpu().contiguous()

        # Drop references to the submitted inputs before the next chunk is
        # prepared, then keep only the cache returned by this chunk.
        api_settings = settings.get("_api")
        if isinstance(api_settings, dict):
            api_settings["flashvsr_continue_cache"] = None
        release_input_payload = getattr(job, "release_input_payload", None)
        if callable(release_input_payload):
            release_input_payload()
        progress.continue_cache = getattr(returned_video_item, "flashvsr_continue_cache", None)
        settings = None
        gc.collect()

        returned_frame_count = int(video_tensor_uint8.shape[1])
        print(f"[Process Full Video] Chunk {chunk_index}: returned video tensor has {returned_frame_count} frame(s); control video lasts {plan_requested_frames} frame(s)")
        self._register_chunk_resolution(progress, chunk_index, video_tensor_uint8)

        remaining_unique_frames = context.requested_unique_frames - (context.resumed_unique_frames + progress.written_unique_frames)
        # chunk_index is 1-based, so context.plans[chunk_index] is the NEXT plan.
        next_overlap_frames = context.plans[chunk_index].overlap_frames if chunk_index < len(context.plans) else 0
        leading_overlap_already_written = chunk_index == 1 and context.resumed_unique_frames > 0
        write_start = plan_overlap_frames if leading_overlap_already_written else 0
        write_end = plan_requested_frames - next_overlap_frames
        frames_to_write = write_end - write_start
        if frames_to_write <= 0:
            raise gr.Error(f"Chunk {chunk_index} has no new frame to write after keeping {next_overlap_frames} lookahead frame(s).")
        if frames_to_write > remaining_unique_frames:
            raise gr.Error(f"Chunk {chunk_index} would write {frames_to_write} frame(s), but only {remaining_unique_frames} frame(s) remain.")
        if returned_frame_count < write_end:
            raise gr.Error(f"Chunk {chunk_index} returned {returned_frame_count} frame(s), but {write_end} frame(s) were required.")

        self._ensure_writer_started(context, progress, process_is_hdr=False)
        if (
            context.continued_mode
            and progress.write_state.output_path_for_write != context.output_path
            and callable(getattr(context.system_handler, "move_continue_cache", None))
        ):
            context.system_handler.move_continue_cache(context.output_path, progress.write_state.output_path_for_write)
        last_frame_tensor = progress.write_state.write_chunk(
            process_is_hdr=False,
            video_tensor_hdr=None,
            video_tensor_uint8=video_tensor_uint8,
            start_frame=write_start,
            frame_count=frames_to_write,
        )
        progress.written_unique_frames += frames_to_write
        self.preview_state["image"] = video.frame_to_image(last_frame_tensor)
        if progress.continue_cache is not None and hasattr(context.system_handler, "save_continue_cache"):
            context.system_handler.save_continue_cache(
                progress.continue_cache,
                progress.write_state.output_path_for_write,
                metadata={"written_unique_frames": int(context.resumed_unique_frames + progress.written_unique_frames), "chunk": int(chunk_index)},
            )
        video_tensor_uint8 = None

        progress.completed_chunks += 1
        yield self._chunk_finished_update(context, progress, chunk_index, frames_to_write)
        return True

    def _run_model_chunk(self, context: ProcessContext, progress: ChunkProgress, chunk_index: int, plan, callbacks, needs_video_source: bool):
        """Run one chunk as a regular generation task; return False when the run must stop."""
        actual_done = context.resumed_unique_frames + progress.written_unique_frames
        plan_overlap_frames = plan.overlap_frames
        plan_requested_frames = plan.requested_frames
        actual_control_start_frame = plan.control_start_frame
        actual_control_end_frame = actual_control_start_frame + plan_requested_frames - 1
        model_video_length = plan_requested_frames if plan_overlap_frames <= 0 else plan_requested_frames - plan_overlap_frames + 1

        settings = build_task_settings(context.process_settings, is_user_process=context.is_user_process)
        chunk_prompt_start_seconds = float(actual_done) / float(context.fps_float)
        settings["model_type"] = context.model_type
        settings["prompt"] = prompts.resolve_prompt_for_chunk(context.prompt_schedule, chunk_prompt_start_seconds, context.default_prompt_text)
        settings["resolution"] = progress.resolved_resolution or context.budget_resolution
        settings["video_length"] = model_video_length
        settings["sliding_window_overlap"] = plan_overlap_frames if plan_overlap_frames > 0 else 1
        settings["image_prompt_type"] = "V" if needs_video_source else ""
        settings["audio_prompt_type"] = "K"
        if context.is_user_process:
            settings["force_fps"] = "control"
        settings["video_guide"] = build_virtual_media_path(
            context.source_path,
            start_frame=actual_control_start_frame,
            end_frame=actual_control_end_frame,
            audio_track_no=context.selected_audio_track,
        )
        if context.uses_builtin_outpaint_ui:
            settings["video_guide_outpainting_ratio"] = context.active_target_ratio
        elif not context.has_outpaint_setting:
            settings.pop("video_guide_outpainting_ratio", None)
        if context.use_lora_strength_override:
            settings["loras_multipliers"] = str(context.active_process_strength)
        if needs_video_source:
            settings["video_source"] = video.build_process_full_video_source_path(hdr=context.process_is_hdr)
        else:
            settings.pop("video_source", None)

        self.reset_live_chunk_status(context.state)
        job = self.api_session.submit_task(settings, callbacks=callbacks)
        result = yield from self._wait_for_job(job, context, progress, callbacks)
        if not self._job_succeeded(result, progress, chunk_index):
            return False

        self._record_generated_files(progress, result.generated_files)
        if context.process_is_hdr:
            returned_video_item = next((item for item in result.artifacts if item.video_tensor_hdr is not None), None)
            returned_tensor = None if returned_video_item is None else returned_video_item.video_tensor_hdr
        else:
            returned_video_item = next((item for item in result.artifacts if item.video_tensor_uint8 is not None), None)
            returned_tensor = None if returned_video_item is None else returned_video_item.video_tensor_uint8
        if returned_video_item is None or not torch.is_tensor(returned_tensor):
            raise gr.Error(f"Chunk {chunk_index} completed without returned video tensor data.")
        video_tensor_hdr = returned_tensor.detach().cpu() if context.process_is_hdr else None
        video_tensor_uint8 = tonemap_hdr_tensor_to_uint8(video_tensor_hdr) if context.process_is_hdr else returned_tensor.detach().cpu()

        returned_frame_count = int(video_tensor_uint8.shape[1])
        expected_frame_count = plan_requested_frames
        minimum_returned_frames = expected_frame_count - 1 if expected_frame_count > 1 else 1
        if not context.process_is_hdr and returned_frame_count < minimum_returned_frames:
            # The returned tensor is too short: fall back to decoding the first
            # generated video file, if it holds enough frames.
            video_candidates = [
                path
                for path in result.generated_files
                if isinstance(path, str) and os.path.isfile(path) and str(Path(path).suffix).lower() in {".mp4", ".mkv", ".mov", ".avi"}
            ]
            if video_candidates:
                decoded_tensor = video.load_video_tensor_from_file(video_candidates[0])
                decoded_frame_count = int(decoded_tensor.shape[1])
                print(f"[Process Full Video] Chunk {chunk_index}: returned video tensor has {returned_frame_count} frame(s); decoded chunk file has {decoded_frame_count} frame(s)")
                if decoded_frame_count >= minimum_returned_frames:
                    video_tensor_uint8 = decoded_tensor
                    returned_frame_count = decoded_frame_count
        print(f"[Process Full Video] Chunk {chunk_index}: returned video tensor has {returned_frame_count} frame(s); control video lasts {expected_frame_count} frame(s)")
        self._register_chunk_resolution(progress, chunk_index, video_tensor_uint8)

        # The leading overlap frames of each chunk are skipped when writing.
        skip_frames = plan_overlap_frames
        remaining_unique_frames = context.requested_unique_frames - (context.resumed_unique_frames + progress.written_unique_frames)
        expected_unique_frames = plan_requested_frames - skip_frames
        if expected_unique_frames <= 0:
            raise gr.Error(f"Chunk {chunk_index} has no writable frame in the computed plan.")
        if expected_unique_frames > remaining_unique_frames:
            raise gr.Error(f"Chunk {chunk_index} would write {expected_unique_frames} frame(s), but only {remaining_unique_frames} frame(s) remain.")
        writable_frame_count = returned_frame_count - skip_frames
        if writable_frame_count < expected_unique_frames:
            raise gr.Error(f"Chunk {chunk_index} returned {writable_frame_count} writable frame(s), but {expected_unique_frames} frame(s) were required.")
        frames_to_write = expected_unique_frames

        self._ensure_writer_started(context, progress, process_is_hdr=context.process_is_hdr)
        last_frame_tensor = progress.write_state.write_chunk(
            process_is_hdr=context.process_is_hdr,
            video_tensor_hdr=video_tensor_hdr,
            video_tensor_uint8=video_tensor_uint8,
            start_frame=skip_frames,
            frame_count=frames_to_write,
        )
        progress.written_unique_frames += frames_to_write
        self.preview_state["image"] = video.frame_to_image(last_frame_tensor)

        overlap_source_tensor = video_tensor_hdr if context.process_is_hdr and video_tensor_hdr is not None else video_tensor_uint8
        next_overlap_tensor = video.update_process_full_video_overlap_buffer(
            overlap_source_tensor[:, skip_frames:skip_frames + frames_to_write],
            context.overlap_frames,
            context.processing_fps,
            hdr=context.process_is_hdr,
        )
        if next_overlap_tensor is not None and int(next_overlap_tensor.shape[1]) > 0:
            next_overlap_count = int(next_overlap_tensor.shape[1])
            next_overlap_start_frame = context.start_frame + context.resumed_unique_frames + progress.written_unique_frames - next_overlap_count
            print(f"[Process Full Video] Chunk {chunk_index}: next overlap buffer {frames.describe_frame_range(next_overlap_start_frame, next_overlap_count)}")

        progress.completed_chunks += 1
        progress.chunk_output_paths = media.delete_released_chunk_outputs(
            context.state,
            progress.chunk_output_paths,
            preserve_paths=[progress.last_segment_path] if progress.last_segment_path else None,
        )
        yield self._chunk_finished_update(context, progress, chunk_index, frames_to_write)
        return True