|
|
from time import perf_counter |
|
|
from typing import Any, Iterator |
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
from modules import scripts, shared as webui_shared |
|
|
from modules.options import Options |
|
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, fix_seed |
|
|
from modules.shared_state import State |
|
|
from modules.styles import StyleDatabase |
|
|
|
|
|
from temporal.interop import EXTENSION_DIR, get_cn_units |
|
|
from temporal.pipeline_modules.measuring import MeasuringModule |
|
|
from temporal.preset import Preset |
|
|
from temporal.project import Project |
|
|
from temporal.shared import shared |
|
|
from temporal.ui import CallbackInputs, CallbackOutputs, UI |
|
|
from temporal.ui.fs_store_list import FSStoreList, FSStoreListEntry |
|
|
from temporal.ui.gradio_widget import GradioWidget |
|
|
from temporal.ui.options_editor import OptionsEditor |
|
|
from temporal.ui.paginator import Paginator |
|
|
from temporal.ui.project_editor import ProjectEditor |
|
|
from temporal.ui.video_renderer_editor import VideoRendererEditor |
|
|
from temporal.utils import logging |
|
|
from temporal.utils.fs import load_text |
|
|
from temporal.utils.image import PILImage, ensure_image_dims, np_to_pil, pil_to_np |
|
|
from temporal.utils.object import copy_with_overrides |
|
|
from temporal.utils.time import wait_until |
|
|
from temporal.video_renderer import video_render_queue |
|
|
from temporal.web_ui import process_images |
|
|
|
|
|
|
|
|
|
|
|
# Typed aliases for the webui's dynamically-assigned globals. `getattr` is used
# (rather than plain attribute access) because `modules.shared` populates these
# at runtime, so static type checkers cannot see them as declared attributes;
# the explicit annotations restore type information for the rest of this file.
opts: Options = getattr(webui_shared, "opts")
prompt_styles: StyleDatabase = getattr(webui_shared, "prompt_styles")
state: State = getattr(webui_shared, "state")
|
|
|
|
|
|
|
|
class TemporalScript(scripts.Script):
    """img2img script entry point for the Temporal extension.

    Responsibilities:
    - `ui` builds the extension's Gradio interface (presets, projects,
      pipeline editor, video rendering, measuring, tools, settings, help)
      and registers all widget callbacks.
    - `run` executes the iterative img2img loop for the selected project,
      feeding each iteration's output back in as the next input.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Initialize extension-wide shared state (options, preset store,
        # project store) from the extension's own directories.
        shared.init(EXTENSION_DIR / "settings", EXTENSION_DIR / "presets")

    def title(self) -> str:
        """Name shown in the webui's script dropdown."""
        return "Temporal"

    def show(self, is_img2img: bool) -> Any:
        # Only visible on the img2img tab.
        return is_img2img

    def ui(self, is_img2img: bool) -> Any:
        """Build the extension UI and wire its callbacks.

        Returns the Gradio components produced by `self._ui.finalize(...)`;
        their values are handed back to `run` as `*args` and reassembled
        there via `self._ui.recombine`.
        """
        self._ui = UI()

        # Persistent stores: presets capture the whole UI state; projects
        # hold the per-animation data on disk.
        stored_preset = FSStoreList(label = "Preset", store = shared.preset_store, features = ["load", "save", "rename", "delete"])
        stored_project = FSStoreList(label = "Project", store = shared.project_store, features = ["load", "rename", "delete"])

        with GradioWidget(gr.Tab, label = "General"):
            load_parameters = GradioWidget(gr.Checkbox, label = "Load parameters", value = True)
            continue_from_last_frame = GradioWidget(gr.Checkbox, label = "Continue from last frame", value = True)
            iter_count = GradioWidget(gr.Number, label = "Iteration count", precision = 0, minimum = 1, step = 1, value = 100)

        with GradioWidget(gr.Tab, label = "Information"):
            description = GradioWidget(gr.Textbox, label = "Description", lines = 5, max_lines = 5, interactive = False)
            gallery = GradioWidget(gr.Gallery, label = "Gallery", columns = 4, object_fit = "contain", preview = True)
            # Frame gallery is paginated by page number and by parallel
            # sequence index (a project can run several sequences at once).
            gallery_page = Paginator(label = "Page", minimum = 1, value = 1)
            gallery_parallel = Paginator(label = "Parallel", minimum = 1, value = 1)

        with GradioWidget(gr.Tab, label = "Pipeline"):
            project = ProjectEditor()

        with GradioWidget(gr.Tab, label = "Video Rendering"):
            video_renderer = VideoRendererEditor(value = shared.video_renderer)
            video_parallel_index = GradioWidget(gr.Number, label = "Parallel index", precision = 0, minimum = 1, step = 1, value = 1)

            with GradioWidget(gr.Row):
                render_draft = GradioWidget(gr.Button, value = "Render draft")
                render_final = GradioWidget(gr.Button, value = "Render final")

            video_preview = GradioWidget(gr.Video, label = "Preview", format = "mp4", interactive = False)

        with GradioWidget(gr.Tab, label = "Measuring"):
            measuring_parallel_index = GradioWidget(gr.Number, label = "Parallel index", precision = 0, minimum = 1, step = 1, value = 1)
            render_graphs = GradioWidget(gr.Button, value = "Render graphs")
            graph_gallery = GradioWidget(gr.Gallery, label = "Graphs", columns = 4, object_fit = "contain", preview = True)

        with GradioWidget(gr.Tab, label = "Tools"):
            delete_intermediate_frames = GradioWidget(gr.Button, value = "Delete intermediate frames")
            delete_session_data = GradioWidget(gr.Button, value = "Delete session data")

        with GradioWidget(gr.Tab, label = "Settings"):
            apply_settings = GradioWidget(gr.Button, value = "Apply")
            options = OptionsEditor(value = shared.options)

        with GradioWidget(gr.Tab, label = "Help"):
            # One collapsible section per bundled markdown document; missing
            # files fall back to an empty string via `load_text`'s default.
            for file_name, title in [
                ("main.md", "Main"),
                ("tab_project.md", "Project tab"),
                ("tab_pipeline.md", "Pipeline tab"),
                ("tab_video_rendering.md", "Video Rendering tab"),
                ("tab_measuring.md", "Measuring tab"),
                ("tab_settings.md", "Settings tab"),
            ]:
                with GradioWidget(gr.Accordion, label = title, open = False):
                    GradioWidget(gr.Markdown, value = load_text(EXTENSION_DIR / "docs" / "temporal" / file_name, ""))

        @stored_preset.callback("load", [stored_preset], [stored_project, load_parameters, continue_from_last_frame, iter_count, project, video_renderer])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            # Loading a preset: push each stored value back into the
            # corresponding widget. `.data.data` is the raw dict inside the
            # `Preset` object held by the store entry.
            data = inputs[stored_preset].data.data

            return {
                stored_project: {"value": data["stored_project"]},
                load_parameters: {"value": data["load_parameters"]},
                continue_from_last_frame: {"value": data["continue_from_last_frame"]},
                iter_count: {"value": data["iter_count"]},
                project: {"value": data["project"], "preview_states": data["preview_states"]},
                video_renderer: {"value": data["video_renderer"]},
            }

        @stored_preset.callback("save", [stored_project, load_parameters, continue_from_last_frame, iter_count, project, video_renderer], [stored_preset])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            # Saving a preset: snapshot the current widget values into a new
            # `Preset`. Keys mirror the "load" callback above.
            return {stored_preset: {"value": Preset({
                "stored_project": inputs[stored_project].name,
                "load_parameters": inputs[load_parameters],
                "continue_from_last_frame": inputs[continue_from_last_frame],
                "iter_count": inputs[iter_count],
                "project": inputs[project],
                # Preview toggles live in shared state, not in a widget.
                "preview_states": shared.previewed_modules,
                "video_renderer": inputs[video_renderer],
            })}}

        @stored_project.callback("change", [stored_project], [description, gallery, gallery_page, gallery_parallel])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            # Selecting a different project: refresh the Information tab and
            # reset both paginators to the first page.
            project_obj = inputs[stored_project].data

            return {
                description: {"value": project_obj.get_description()},
                gallery: {"value": project_obj.list_all_frame_paths()[:shared.options.ui.gallery_size]},
                gallery_page: {"value": 1},
                gallery_parallel: {"value": 1},
            }

        @stored_project.callback("load", [stored_project], [project])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            # Explicit "load" copies the stored project into the editor.
            return {project: {"value": inputs[stored_project].data}}

        # Both paginators share one handler: either change re-slices the
        # frame list for the selected page/parallel sequence.
        @gallery_page.callback("change", [stored_project, gallery_page, gallery_parallel], [gallery])
        @gallery_parallel.callback("change", [stored_project, gallery_page, gallery_parallel], [gallery])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            project_obj = inputs[stored_project].data
            page = inputs[gallery_page]
            parallel = inputs[gallery_parallel]
            gallery_size = shared.options.ui.gallery_size

            return {gallery: {"value": project_obj.list_all_frame_paths(parallel)[(page - 1) * gallery_size:page * gallery_size]}}

        def render_video(inputs: CallbackInputs, is_final: bool) -> Iterator[CallbackOutputs]:
            # Shared generator for both render buttons: disable the buttons,
            # render synchronously (waiting for the background queue), then
            # re-enable and show the result.
            yield {
                render_draft: {"interactive": False},
                render_final: {"interactive": False},
            }

            shared.video_renderer = inputs[video_renderer]

            video_path = inputs[stored_project].data.render_video(shared.video_renderer, is_final, inputs[video_parallel_index])
            # `render_video` appears to enqueue work; block until the render
            # queue drains before exposing the file to the preview widget.
            wait_until(lambda: not video_render_queue.busy)

            yield {
                render_draft: {"interactive": True},
                render_final: {"interactive": True},
                video_preview: {"value": video_path.as_posix()},
            }

        @render_draft.callback("click", [stored_project, video_renderer, video_parallel_index], [render_draft, render_final, video_preview])
        def _(inputs: CallbackInputs) -> Iterator[CallbackOutputs]:
            yield from render_video(inputs, False)

        @render_final.callback("click", [stored_project, video_renderer, video_parallel_index], [render_draft, render_final, video_preview])
        def _(inputs: CallbackInputs) -> Iterator[CallbackOutputs]:
            yield from render_video(inputs, True)

        @render_graphs.callback("click", [stored_project, measuring_parallel_index], [graph_gallery])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            # Plot every enabled measuring module for the chosen parallel
            # sequence (UI index is 1-based; `plot` takes a 0-based index).
            return {graph_gallery: {"value": [
                x.plot(inputs[measuring_parallel_index] - 1)
                for x in inputs[stored_project].data.pipeline.modules
                if isinstance(x, MeasuringModule) and x.enabled
            ]}}

        @delete_intermediate_frames.callback("click", [stored_project], [description, gallery])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            project_obj = inputs[stored_project].data
            project_obj.delete_intermediate_frames()

            # Refresh the Information tab to reflect the deletion.
            return {
                description: {"value": project_obj.get_description()},
                gallery: {"value": project_obj.list_all_frame_paths()[:shared.options.ui.gallery_size]},
            }

        @delete_session_data.callback("click", [stored_project], [])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            project_obj = inputs[stored_project].data
            project_obj.delete_session_data()
            # Persist the project so the deletion survives a restart.
            project_obj.save(project_obj.path)

            return {}

        @apply_settings.callback("click", [options], [])
        def _(inputs: CallbackInputs) -> CallbackOutputs:
            shared.options = inputs[options]
            shared.options.save(EXTENSION_DIR / "settings")

            return {}

        # Only these components feed into `run`; the order here must match
        # the unpacking order in `run` below.
        return self._ui.finalize(stored_project, load_parameters, continue_from_last_frame, iter_count, project)

    def run(self, p: StableDiffusionProcessingImg2Img, *args: Any) -> Any:
        """Execute the iterative generation loop.

        `args` are the raw component values from `ui`; they are recombined
        into the typed values declared below (same order as `finalize`).
        Returns a `Processed` with the final images of each parallel sequence.
        """
        stored_project: FSStoreListEntry[Project]
        load_parameters: bool
        continue_from_last_frame: bool
        iter_count: int
        project: Project

        stored_project, load_parameters, continue_from_last_frame, iter_count, project = self._ui.recombine(*args)

        # Snapshot webui options so temporary overrides below can be undone
        # on every exit path.
        opts_backup = opts.data.copy()

        # Frames are saved by the project itself; don't let the webui scatter
        # outputs into date-based subdirectories.
        opts.save_to_dirs = False

        if shared.options.live_preview.show_only_finished_images:
            # -1 suppresses per-step live previews — presumably so only
            # finished frames are shown; verify against webui option docs.
            opts.show_progress_every_n_steps = -1

        # Bake the selected styles into the prompts, then clear them so they
        # are not applied a second time during processing.
        p.prompt = prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)
        p.negative_prompt = prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
        p.styles.clear()

        # Resolve a concrete seed now so it stays stable across iterations.
        fix_seed(p)

        # Attach the live processing context to the project object.
        project.path = stored_project.data.path
        project.options = opts
        project.processing = p
        project.controlnet_units = get_cn_units(p)

        if load_parameters:
            project.load(stored_project.data.path)

        if not continue_from_last_frame:
            # Fresh start: wipe previous frames and session state.
            project.delete_all_frames()
            project.delete_session_data()

        if not p.init_images or not isinstance(p.init_images[0], PILImage):
            # No usable initial image supplied: synthesize one noise image
            # per parallel sequence (deterministic per seed + index).
            noises = [
                project.initial_noise.noise.generate((p.height, p.width, 3), p.seed, i)
                for i in range(project.pipeline.parallel)
            ]

            if project.initial_noise.factor < 1.0:
                # Partially denoise the generated noise to get the initial
                # frames; a factor of 1.0 means raw noise is used as-is.
                if not (processed_images := process_images(
                    copy_with_overrides(p,
                        denoising_strength = 1.0 - project.initial_noise.factor,
                        do_not_save_samples = True,
                        do_not_save_grid = True,
                    ),
                    [(np_to_pil(x), p.seed + i, 1) for i, x in enumerate(noises)],
                    shared.options.processing.pixels_per_batch,
                    True,
                )):
                    # Processing failed or was interrupted: restore options
                    # and bail out with whatever init images existed.
                    opts.data.update(opts_backup)

                    return Processed(p, p.init_images)

                # Each entry is a list of images per input; keep the first.
                p.init_images = [image_array[0] for image_array in processed_images]

            else:
                p.init_images = noises

        elif len(p.init_images) != project.pipeline.parallel:
            # Replicate the single provided image across all parallel
            # sequences.
            p.init_images = [p.init_images[0]] * project.pipeline.parallel

        if not project.iteration.images:
            # Seed the iteration state with the initial images, normalized
            # to RGB numpy arrays at the processing resolution.
            project.iteration.images[:] = [pil_to_np(ensure_image_dims(x, "RGB", (p.width, p.height))) for x in p.init_images]

        last_images = project.iteration.images.copy()

        state.job_count = iter_count

        for i in range(iter_count):
            logging.info(f"Iteration {i + 1} / {iter_count}")

            start_time = perf_counter()

            state.job = "Temporal main loop"
            state.job_no = i

            # A falsy result from the pipeline aborts the loop (e.g. user
            # interruption); `last_images` keeps the last completed result.
            if not project.pipeline.run(project):
                break

            last_images = project.iteration.images.copy()

            # Periodic autosave (note: also fires on the first iteration,
            # since 0 % n == 0).
            if i % shared.options.output.autosave_every_n_iterations == 0:
                project.save(project.path)

            end_time = perf_counter()

            logging.info(f"Iteration took {end_time - start_time:.6f} second(s)")

        project.pipeline.finalize(project)
        project.save(project.path)

        state.end()

        # Undo the temporary webui option overrides from above.
        opts.data.update(opts_backup)

        return Processed(p, [np_to_pil(x) for x in last_images])
|
|
|