| import copy |
| from dataclasses import dataclass |
|
|
| import comfy |
| import torch |
|
|
| from .easy_samplers import LTXVBaseSampler, LTXVExtendSampler, LTXVInContextSampler |
| from .latents import LTXVDilateLatent, LTXVSelectLatents |
| from .nodes_registry import comfy_node |
|
|
|
|
@dataclass
class TileConfig:
    """Inputs describing one spatial tile of the tiling grid.

    An instance bundles the slices cut out of every input for a single
    (row, column) cell together with the cell's position in the grid, and is
    consumed by ``LTXVLoopingSampler._process_temporal_chunks``.
    """

    # --- per-tile crops of the inputs -------------------------------------
    # Latent dict for this tile ("samples", plus "noise_mask" when present).
    tile_latents: dict
    # Guiding latents for this tile; None when no guiding latents were given.
    tile_guiding_latents: dict
    # Negative-index latents for this tile; None when none were given.
    tile_negative_index_latents: dict
    # Conditioning images cropped to this tile; None when no images were given.
    tile_keyframes: torch.Tensor
    # One (temporal_tile_index, in_tile_index) pair per keyframe.
    keyframe_per_tile_indices: list
    # Normalizing latents for this tile; None when none were given.
    tile_normalizing_latents: dict

    # --- tile geometry, in latent units -----------------------------------
    tile_height: int
    tile_width: int

    # --- position of the tile inside the grid -----------------------------
    # Row index of the tile.
    v: int
    # Column index of the tile.
    h: int
    # Total number of tile rows.
    vertical_tiles: int
    # Total number of tile columns.
    horizontal_tiles: int

    # Seed of the noise object before any per-tile seed was derived from it.
    first_seed: int
|
|
|
|
@dataclass
class SamplingConfig:
    """Sampling parameters shared by every temporal chunk of a tile.

    Built by ``LTXVLoopingSampler.sample`` and read by
    ``LTXVLoopingSampler._process_temporal_chunks``.
    """

    # --- temporal tiling (``sample`` stores these in latent frames) --------
    temporal_tile_size: int
    temporal_overlap: int
    # Conditioning strength on the frames carried over from the previous chunk.
    temporal_overlap_cond_strength: float

    # --- conditioning strengths --------------------------------------------
    cond_image_strength: float
    guiding_strength: float
    adain_factor: float

    # --- negative-index (long-term context) conditioning -------------------
    # ``sample`` sets this to minus the frame count of the negative-index
    # latents, or -1 when none were supplied.
    optional_negative_index: int
    optional_negative_index_strength: float

    # Per-chunk positive conditionings, or None to keep the guider's own.
    optional_positive_conditionings: list

    # --- VAE scale factors between latent and pixel units ------------------
    time_scale_factor: int
    width_scale_factor: int
    height_scale_factor: int

    # Seed offsets looked up by temporal chunk index.
    per_tile_seed_offsets: list

    # --- step window during which the guiding latents are applied ----------
    guiding_start_step: int
    guiding_end_step: int
|
|
|
|
@dataclass
class ModelConfig:
    """Bundle of the model-side objects forwarded to the sampler helpers.

    The fields mirror the MODEL / VAE / NOISE / SAMPLER / SIGMAS / GUIDER
    inputs of ``LTXVLoopingSampler`` and are passed through unchanged, except
    that ``noise.seed`` is reassigned for every temporal chunk.
    """

    # Diffusion model.
    model: object
    # VAE passed to every sampler helper.
    vae: object
    # Noise source; its ``seed`` attribute is overwritten per chunk.
    noise: object
    # Sampler object.
    sampler: object
    # Sigma schedule.
    sigmas: object
    # Guider; copied per chunk when per-chunk prompts are supplied.
    guider: object
|
|
|
|
@comfy_node(
    name="LTXVLoopingSampler",
)
class LTXVLoopingSampler:
    """ComfyUI node that samples a long video tile by tile.

    The latent volume is split into a grid of overlapping spatial tiles. Each
    spatial tile is generated as a sequence of overlapping temporal chunks,
    and the finished tiles are blended back into one latent with linear
    weights across the spatial overlap.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the ComfyUI input specification of the node."""
        return {
            "required": {
                "model": ("MODEL", {"tooltip": "The diffusion model to use."}),
                "vae": ("VAE", {"tooltip": "The VAE to use."}),
                "noise": ("NOISE", {"tooltip": "The noise to use."}),
                "sampler": ("SAMPLER", {"tooltip": "The sampler to use."}),
                "sigmas": ("SIGMAS", {"tooltip": "The sigmas to use."}),
                "guider": (
                    "GUIDER",
                    {"tooltip": "The guider to use, must be a STGGuiderAdvanced."},
                ),
                "latents": (
                    "LATENT",
                    {
                        "tooltip": "The latents to use for creating the long video, they can be guiding latents or empty latents when no guidance is used."
                    },
                ),
                "temporal_tile_size": (
                    "INT",
                    {
                        "default": 80,
                        "min": 24,
                        "max": 1000,
                        "step": 8,
                        "tooltip": "The size of the temporal tile to use for the sampling, in pixel frames, in addition to the overlapping region.",
                    },
                ),
                "temporal_overlap": (
                    "INT",
                    {
                        "default": 24,
                        "min": 16,
                        "max": 80,
                        "step": 8,
                        "tooltip": "The overlap between the temporal tiles, in pixel frames.",
                    },
                ),
                "guiding_strength": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "The strength of the conditioning on guiding latents, when optional_guiding_latents are provided.",
                    },
                ),
                "temporal_overlap_cond_strength": (
                    "FLOAT",
                    {
                        "default": 0.5,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "The strength of the conditioning on the latents from the previous temporal tile.",
                    },
                ),
                "cond_image_strength": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "The strength of the conditioning on the optional_cond_images.",
                    },
                ),
                "horizontal_tiles": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 6,
                        "tooltip": "Number of horizontal spatial tiles.",
                    },
                ),
                "vertical_tiles": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 6,
                        "tooltip": "Number of vertical spatial tiles.",
                    },
                ),
                "spatial_overlap": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 8,
                        "tooltip": "Overlap between spatial tiles.",
                    },
                ),
            },
            "optional": {
                "optional_cond_images": (
                    "IMAGE",
                    {
                        "tooltip": "The image to use for conditioning the first frame in the video (i2v setup). If not provided, the first frame will be unconditioned (t2v setup). The image will be resized to the size of the first frame."
                    },
                ),
                "optional_guiding_latents": (
                    "LATENT",
                    {
                        "tooltip": "The latents to use for guiding the sampling, typically with an IC-LoRA."
                    },
                ),
                "adain_factor": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                        "tooltip": "The strength of the AdaIn operation used to fix the statistics of each new generated temporal tile, to prevent accumulated oversaturation.",
                    },
                ),
                "optional_positive_conditionings": (
                    "CONDITIONING",
                    {
                        "tooltip": "Optional way to provide changing positive prompts, one per temporal tile, using the MultiPromptProvider node."
                    },
                ),
                "optional_negative_index_latents": (
                    "LATENT",
                    {
                        "tooltip": "Special optional latents to condition on a negative index before each new temporal tile as a way to provide long term context during video generation."
                    },
                ),
                "guiding_start_step": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 1000,
                        "tooltip": "The step at which the guiding latents start to be used.",
                    },
                ),
                "guiding_end_step": (
                    "INT",
                    {
                        "default": 1000,
                        "min": 0,
                        "max": 1000,
                        "tooltip": "The step at which the guiding latents stop to be used.",
                    },
                ),
                "optional_cond_image_indices": (
                    "STRING",
                    {
                        "default": "0",
                        # Only negative values are wrapped around the video
                        # length (see _parse_comma_separated_string); the
                        # tooltip previously claimed all indices count from
                        # the end.
                        "tooltip": "The indices of the keyframes to use for the sampling, separated by commas. Negative indices are counted from the end of the video.",
                    },
                ),
                "optional_normalizing_latents": (
                    "LATENT",
                    {
                        "tooltip": "The latents to use for normalizing the output latents, they will be used to normalize the output latents to the same statistics as the input latents."
                    },
                ),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("denoised_output",)

    FUNCTION = "sample"
    CATEGORY = "sampling"

    def _extract_latent_spatial_tile(self, latent_dict, v_start, v_end, h_start, h_end):
        """Crop a latent dict to the rows ``v_start:v_end`` and columns ``h_start:h_end``.

        Args:
            latent_dict: Dict with a 5-D ``"samples"`` tensor and optionally a
                ``"noise_mask"`` entry, or None.
            v_start, v_end: Row range of the tile (dimension 3).
            h_start, h_end: Column range of the tile (dimension 4).

        Returns:
            None when ``latent_dict`` is None, otherwise a new dict holding the
            cropped ``"samples"`` and, when a non-None mask exists, the cropped
            ``"noise_mask"``.
        """
        if latent_dict is None:
            return None
        tile_samples = latent_dict["samples"][:, :, :, v_start:v_end, h_start:h_end]
        if "noise_mask" in latent_dict and latent_dict["noise_mask"] is not None:
            # The mask is indexed like the samples, so it is assumed to be 5-D too.
            tile_masks = latent_dict["noise_mask"][
                :, :, :, v_start:v_end, h_start:h_end
            ]
            return {"samples": tile_samples, "noise_mask": tile_masks}
        return {"samples": tile_samples}

    def _extract_spatial_tile(
        self,
        latents,
        optional_guiding_latents,
        optional_negative_index_latents,
        optional_normalizing_latents,
        optional_keyframes,
        v_start,
        v_end,
        h_start,
        h_end,
        height_scale_factor,
        width_scale_factor,
    ):
        """Crop every input to one spatial tile.

        The latent inputs are cropped in latent coordinates. The keyframe
        images are indexed as ``[:, rows, cols, :]`` with the tile bounds
        multiplied by the VAE scale factors.

        Returns:
            Tuple ``(tile_latents, tile_guiding_latents,
            tile_negative_index_latents, tile_keyframes,
            tile_normalizing_latents)``; an entry is None when the
            corresponding input is None.
        """
        tile_latents = self._extract_latent_spatial_tile(
            latents, v_start, v_end, h_start, h_end
        )
        tile_guiding_latents = self._extract_latent_spatial_tile(
            optional_guiding_latents, v_start, v_end, h_start, h_end
        )
        tile_negative_index_latents = self._extract_latent_spatial_tile(
            optional_negative_index_latents, v_start, v_end, h_start, h_end
        )
        tile_normalizing_latents = self._extract_latent_spatial_tile(
            optional_normalizing_latents, v_start, v_end, h_start, h_end
        )

        if optional_keyframes is not None:
            # Convert the latent-space tile bounds to pixel coordinates.
            img_h_start = v_start * height_scale_factor
            img_h_end = v_end * height_scale_factor
            img_w_start = h_start * width_scale_factor
            img_w_end = h_end * width_scale_factor
            tile_keyframes = optional_keyframes[
                :, img_h_start:img_h_end, img_w_start:img_w_end, :
            ]
        else:
            tile_keyframes = None

        return (
            tile_latents,
            tile_guiding_latents,
            tile_negative_index_latents,
            tile_keyframes,
            tile_normalizing_latents,
        )

    def _select_temporal_chunk(self, latent_dict, start_index, end_index):
        """Select frames ``start_index`` .. ``end_index - 1`` of a latent dict.

        The last index is clamped to the final frame of ``latent_dict``.
        Returns None when ``latent_dict`` is None.
        """
        if latent_dict is None:
            return None
        last_index = min(end_index - 1, latent_dict["samples"].shape[2] - 1)
        return LTXVSelectLatents().select_latents(
            latent_dict, start_index, last_index
        )[0]

    def _process_temporal_chunks(
        self,
        tile_config: TileConfig,
        sampling_config: SamplingConfig,
        model_config: ModelConfig,
    ):
        """Generate one spatial tile as a sequence of overlapping temporal chunks.

        The first chunk is sampled with ``LTXVInContextSampler`` when guiding
        latents exist and with ``LTXVBaseSampler`` otherwise. Every later
        chunk extends the running result through ``LTXVExtendSampler``.
        ``model_config.noise.seed`` is reassigned before each chunk.

        Args:
            tile_config: Cropped inputs and grid position of the tile.
            sampling_config: Temporal tiling and conditioning parameters
                (tile size and overlap in latent frames).
            model_config: Model, VAE, noise, sampler, sigmas and guider.

        Returns:
            The latent dict returned by the last sampler call.

        Raises:
            ValueError: If the temporal overlap is not smaller than the
                temporal tile size, or if the latents have too few frames to
                form a single chunk.
        """
        num_latent_frames = tile_config.tile_latents["samples"].shape[2]
        tile_size = sampling_config.temporal_tile_size
        stride = tile_size - sampling_config.temporal_overlap
        if stride <= 0:
            # A zero step makes range() raise and a negative one yields
            # meaningless chunk bounds, so fail with an explicit message.
            raise ValueError(
                "temporal_overlap must be smaller than temporal_tile_size "
                f"(got tile size {tile_size} and overlap "
                f"{sampling_config.temporal_overlap}, in latent frames)."
            )

        # Chunk n covers [n * stride, tile_size + n * stride); zip() stops at
        # the shorter range, i.e. at the first chunk that reaches the end.
        range_stop = num_latent_frames + stride
        chunk_bounds = zip(
            range(0, range_stop, stride),
            range(tile_size, range_stop, stride),
        )

        tile_out_latents = None
        first_tile_out_latents = None

        for i_temporal_tile, (start_index, end_index) in enumerate(chunk_bounds):
            guiding_latent_chunk = self._select_temporal_chunk(
                tile_config.tile_guiding_latents, start_index, end_index
            )
            latent_chunk = self._select_temporal_chunk(
                tile_config.tile_latents, start_index, end_index
            )

            print(
                "Processing temporal chunk at index",
                start_index,
                "to",
                min(end_index - 1, num_latent_frames - 1),
            )

            seed_offset = self._get_per_tile_value(
                sampling_config.per_tile_seed_offsets, i_temporal_tile
            )

            if tile_config.tile_normalizing_latents is not None:
                normalizing_latent_chunk = self._select_temporal_chunk(
                    tile_config.tile_normalizing_latents, start_index, end_index
                )
                normalize_per_frame = True
                print(
                    "Normalizing latents provided, normalizing per frame and channel with factor",
                    sampling_config.adain_factor,
                )
            else:
                # Fall back to the output of the first chunk as the reference.
                normalizing_latent_chunk = first_tile_out_latents
                normalize_per_frame = False
                print(
                    "No normalizing latents provided, normalizing per channel using first chunk with factor",
                    sampling_config.adain_factor,
                )

            model_config.noise.seed = self._calculate_tile_seed(
                tile_config.first_seed,
                start_index,
                tile_config.vertical_tiles,
                tile_config.horizontal_tiles,
                tile_config.v,
                tile_config.h,
                seed_offset,
            )

            new_guider = self._prepare_guider_for_chunk(
                model_config.guider,
                sampling_config.optional_positive_conditionings,
                i_temporal_tile,
            )

            # NOTE(review): keyframe image i is matched positionally with
            # entry i of keyframe_per_tile_indices. Entries skipped by
            # _calculate_keyframe_per_tile_indices would shift this
            # alignment -- confirm callers never pass out-of-range indices.
            chunk_keyframe_slots = [
                (i_keyframe, in_tile_index)
                for i_keyframe, (tile_index, in_tile_index) in enumerate(
                    tile_config.keyframe_per_tile_indices
                )
                if tile_index == i_temporal_tile
            ]
            this_chunk_keyframe_indices = [
                in_tile_index for _, in_tile_index in chunk_keyframe_slots
            ]
            if this_chunk_keyframe_indices and tile_config.tile_keyframes is not None:
                this_chunk_keyframes = torch.cat(
                    [
                        tile_config.tile_keyframes[i_keyframe].unsqueeze(0)
                        for i_keyframe, _ in chunk_keyframe_slots
                    ]
                )
                print(
                    f"Chunk {i_temporal_tile} keyframe indices: {this_chunk_keyframe_indices}"
                )
            else:
                this_chunk_keyframes = None
                print(f"Chunk {i_temporal_tile} has no keyframes")
            # The sampler helpers receive the indices as a comma-separated string.
            this_chunk_keyframe_indices = ",".join(
                [str(i) for i in this_chunk_keyframe_indices]
            )

            if start_index == 0:
                if tile_config.tile_guiding_latents is not None:
                    tile_out_latents = LTXVInContextSampler().sample(
                        vae=model_config.vae,
                        guider=new_guider,
                        sampler=model_config.sampler,
                        sigmas=model_config.sigmas,
                        noise=model_config.noise,
                        guiding_latents=guiding_latent_chunk,
                        optional_cond_images=this_chunk_keyframes,
                        optional_cond_indices=this_chunk_keyframe_indices,
                        num_frames=-1,
                        optional_negative_index_latents=tile_config.tile_negative_index_latents,
                        optional_negative_index=sampling_config.optional_negative_index,
                        optional_negative_index_strength=sampling_config.optional_negative_index_strength,
                        optional_initialization_latents=latent_chunk,
                        cond_image_strength=sampling_config.cond_image_strength,
                        guiding_strength=sampling_config.guiding_strength,
                        guiding_start_step=sampling_config.guiding_start_step,
                        guiding_end_step=sampling_config.guiding_end_step,
                    )[0]
                else:
                    tile_out_latents = LTXVBaseSampler().sample(
                        model=model_config.model,
                        vae=model_config.vae,
                        noise=model_config.noise,
                        sampler=model_config.sampler,
                        sigmas=model_config.sigmas,
                        guider=new_guider,
                        # Latent frame count converted to a pixel frame count.
                        num_frames=(min(tile_size, num_latent_frames) - 1)
                        * sampling_config.time_scale_factor
                        + 1,
                        width=tile_config.tile_width
                        * sampling_config.width_scale_factor,
                        height=tile_config.tile_height
                        * sampling_config.height_scale_factor,
                        optional_cond_images=this_chunk_keyframes,
                        optional_cond_indices=this_chunk_keyframe_indices,
                        crop="center",
                        crf=30,
                        strength=sampling_config.cond_image_strength,
                        optional_negative_index_latents=tile_config.tile_negative_index_latents,
                        optional_negative_index=sampling_config.optional_negative_index,
                        optional_negative_index_strength=sampling_config.optional_negative_index_strength,
                        optional_initialization_latents=latent_chunk,
                        guiding_start_step=sampling_config.guiding_start_step,
                        guiding_end_step=sampling_config.guiding_end_step,
                    )[0]
                # Keep an independent copy: tile_out_latents is replaced by
                # every later chunk, but this stays the normalization reference.
                first_tile_out_latents = copy.deepcopy(tile_out_latents)
            else:
                tile_out_latents = LTXVExtendSampler().sample(
                    model=model_config.model,
                    vae=model_config.vae,
                    sampler=model_config.sampler,
                    sigmas=model_config.sigmas,
                    noise=model_config.noise,
                    latents=tile_out_latents,
                    num_new_frames=(
                        latent_chunk["samples"].shape[2]
                        - sampling_config.temporal_overlap
                    )
                    * sampling_config.time_scale_factor,
                    frame_overlap=sampling_config.temporal_overlap
                    * sampling_config.time_scale_factor,
                    guider=new_guider,
                    strength=sampling_config.temporal_overlap_cond_strength,
                    guiding_strength=sampling_config.guiding_strength,
                    cond_image_strength=sampling_config.cond_image_strength,
                    optional_guiding_latents=guiding_latent_chunk,
                    optional_cond_images=this_chunk_keyframes,
                    optional_cond_indices=this_chunk_keyframe_indices,
                    optional_reference_latents=normalizing_latent_chunk,
                    normalize_per_frame=normalize_per_frame,
                    adain_factor=sampling_config.adain_factor,
                    optional_negative_index_latents=tile_config.tile_negative_index_latents,
                    optional_negative_index=sampling_config.optional_negative_index,
                    optional_negative_index_strength=sampling_config.optional_negative_index_strength,
                    optional_initialization_latents=latent_chunk,
                    guiding_start_step=sampling_config.guiding_start_step,
                    guiding_end_step=sampling_config.guiding_end_step,
                )[0]

        if tile_out_latents is None:
            # Happens when the latents have no more frames than the overlap,
            # so the range of chunk end indices is empty.
            raise ValueError(
                f"No temporal chunk could be formed: the latents have "
                f"{num_latent_frames} frame(s), which does not exceed the "
                f"temporal overlap of {sampling_config.temporal_overlap} "
                f"latent frame(s)."
            )

        return tile_out_latents

    def _create_spatial_weights(
        self,
        tile_shape,
        v,
        h,
        horizontal_tiles,
        vertical_tiles,
        spatial_overlap,
        device,
        dtype,
    ):
        """Build the blending weights of one spatial tile.

        The weights start at one. On every side that borders another tile the
        outermost ``spatial_overlap`` rows or columns are scaled by a linear
        ramp: 0 -> 1 on the leading (left/top) side and 1 -> 0 on the
        trailing (right/bottom) side.

        Args:
            tile_shape: Shape of the tile's 5-D sample tensor.
            v, h: Row and column index of the tile in the grid.
            horizontal_tiles, vertical_tiles: Grid dimensions.
            spatial_overlap: Width of the blended border.
            device, dtype: Placement of the returned tensor.

        Returns:
            Tensor of shape ``tile_shape`` with the blending weights.
        """
        tile_weights = torch.ones(tile_shape, device=device, dtype=dtype)

        # Horizontal ramps along the last dimension.
        if h > 0:
            h_blend = torch.linspace(0, 1, spatial_overlap, device=device, dtype=dtype)
            tile_weights[:, :, :, :, :spatial_overlap] *= h_blend.view(1, 1, 1, 1, -1)
        if h < horizontal_tiles - 1:
            h_blend = torch.linspace(1, 0, spatial_overlap, device=device, dtype=dtype)
            tile_weights[:, :, :, :, -spatial_overlap:] *= h_blend.view(1, 1, 1, 1, -1)

        # Vertical ramps along the second-to-last dimension.
        if v > 0:
            v_blend = torch.linspace(0, 1, spatial_overlap, device=device, dtype=dtype)
            tile_weights[:, :, :, :spatial_overlap, :] *= v_blend.view(1, 1, 1, -1, 1)
        if v < vertical_tiles - 1:
            v_blend = torch.linspace(1, 0, spatial_overlap, device=device, dtype=dtype)
            tile_weights[:, :, :, -spatial_overlap:, :] *= v_blend.view(1, 1, 1, -1, 1)

        return tile_weights

    def _calculate_tile_seed(
        self,
        first_seed,
        start_index,
        vertical_tiles,
        horizontal_tiles,
        v,
        h,
        seed_offset,
    ):
        """Derive the noise seed of one temporal chunk of one spatial tile.

        The seed is ``first_seed`` plus the chunk's start index scaled by the
        number of spatial tiles, plus the row-major tile index
        ``v * horizontal_tiles + h``, plus ``seed_offset``.
        """
        return (
            first_seed
            + start_index * (vertical_tiles * horizontal_tiles)
            + v * horizontal_tiles
            + h
            + seed_offset
        )

    def _get_per_tile_value(self, value_list, tile_index):
        """Return ``value_list[tile_index]``, reusing the last entry for larger indices."""
        return value_list[min(tile_index, len(value_list) - 1)]

    def _parse_comma_separated_string(
        self, config_string, default_value, converter_func, optional_total_size=None
    ):
        """Parse a comma-separated string into a list of converted values.

        Args:
            config_string: Text such as ``"0, 8, -1"``. None, an empty string
                or a whitespace-only string selects ``default_value``.
            default_value: String parsed when ``config_string`` is blank.
            converter_func: Callable applied to every stripped item.
            optional_total_size: When given, negative values are shifted by
                this amount so that they index from the end.

        Returns:
            List of converted values, in input order.
        """
        if config_string is None or not config_string.strip():
            config_string = default_value
        values = [converter_func(item.strip()) for item in config_string.split(",")]

        if optional_total_size is not None:
            values = [
                value + optional_total_size if value < 0 else value
                for value in values
            ]
        return values

    def _prepare_guider_for_chunk(
        self, guider, optional_positive_conditionings, chunk_index
    ):
        """Return the guider to use for one temporal chunk.

        Without per-chunk conditionings the original ``guider`` is returned.
        Otherwise a shallow copy is created and given the positive
        conditioning at ``chunk_index`` (the last one is reused for larger
        indices) together with the guider's existing negative conditioning.
        """
        if optional_positive_conditionings is None:
            return guider

        new_guider = copy.copy(guider)
        _, negative = guider.raw_conds
        conditioning_index = min(
            chunk_index, len(optional_positive_conditionings) - 1
        )
        positive = optional_positive_conditionings[conditioning_index]
        new_guider.set_conds(positive, negative)
        new_guider.raw_conds = (positive, negative)
        return new_guider

    def _calculate_keyframe_per_tile_indices(
        self, keyframe_indices, temporal_tile_size, temporal_overlap, num_frames
    ):
        """
        Calculate which temporal tile each keyframe falls into.

        Returns a list of tuples (temporal_tile_index, in_tile_index) for each keyframe.
        - temporal_tile_index: which temporal tile the keyframe falls in
        - in_tile_index: index within that specific temporal tile

        Keyframes whose index is >= num_frames are reported and skipped, so the
        result can be shorter than keyframe_indices.

        First tile: frames [0, temporal_tile_size - 8] (size = temporal_tile_size - 8 + 1)
        Subsequent tiles follow the pattern from _process_temporal_chunks:
        - Tile n starts at: n * (temporal_tile_size - temporal_overlap)
        - Tile n ends at: temporal_tile_size + n * (temporal_tile_size - temporal_overlap) - 1

        For subsequent tiles, keyframes with in_tile_index < temporal_overlap
        don't count as falling in that tile; they are assigned to the previous one.

        All arguments are in pixel frames.
        """
        # NOTE(review): the constant 7 looks like time_scale_factor - 1 for a
        # temporal scale factor of 8 -- confirm before using another VAE.
        result = []
        tile_step = temporal_tile_size - temporal_overlap

        for keyframe_index in keyframe_indices:
            if keyframe_index >= num_frames:
                print(
                    f"Keyframe index {keyframe_index} is greater than num_frames {num_frames}, skipping"
                )
                continue

            # Everything up to frame temporal_tile_size - 8 is in the first tile.
            if keyframe_index < temporal_tile_size - 7:
                result.append((0, keyframe_index))
                continue

            tile_index = 1
            while True:
                tile_start = tile_index * tile_step - 7
                tile_end = temporal_tile_size + tile_index * tile_step - 1 - 7
                print(
                    f"Tile {tile_index} starts at {tile_start} and ends at {tile_end}"
                )

                if keyframe_index <= tile_end:
                    in_tile_index = keyframe_index - tile_start - 7

                    # Frames inside the overlap belong to the previous tile.
                    if in_tile_index < temporal_overlap:
                        tile_index -= 1
                        if tile_index == 0:
                            in_tile_index = keyframe_index
                        else:
                            prev_start = tile_start - tile_step
                            in_tile_index = keyframe_index - prev_start - 7

                    result.append((tile_index, in_tile_index))
                    break

                tile_index += 1

        return result

    def sample(
        self,
        model,
        vae,
        noise,
        sampler,
        sigmas,
        guider,
        latents,
        guiding_strength,
        adain_factor,
        temporal_tile_size,
        temporal_overlap,
        temporal_overlap_cond_strength,
        horizontal_tiles,
        vertical_tiles,
        spatial_overlap,
        optional_cond_images=None,
        cond_image_strength=1.0,
        optional_guiding_latents=None,
        optional_negative_index_latents=None,
        optional_negative_index_strength=1.0,
        optional_positive_conditionings=None,
        guiding_start_step=0,
        guiding_end_step=1000,
        optional_cond_image_indices="0",
        optional_normalizing_latents=None,
        per_tile_seed_offsets="0",
    ):
        """Sample the whole video and return the blended latents.

        Args:
            model, vae, noise, sampler, sigmas, guider: Sampling components.
                ``noise.seed`` is changed per chunk and restored on exit.
            latents: Latent dict whose 5-D ``"samples"`` tensor fixes the
                output size.
            guiding_strength, adain_factor, temporal_overlap_cond_strength,
            cond_image_strength, optional_negative_index_strength,
            guiding_start_step, guiding_end_step: Forwarded to the sampler
                helpers through ``SamplingConfig``.
            temporal_tile_size, temporal_overlap: Temporal tiling in pixel
                frames; floor-divided by the VAE time scale factor.
            horizontal_tiles, vertical_tiles, spatial_overlap: Spatial tiling
                grid and overlap, in latent units.
            optional_cond_images: Conditioning images, resized to the full
                pixel resolution before being tiled.
            optional_guiding_latents: Guiding latents; dilated to the size of
                ``latents`` when smaller by an integer factor.
            optional_negative_index_latents, optional_normalizing_latents:
                Optional latent dicts cropped per tile.
            optional_positive_conditionings: Optional per-chunk positive
                conditionings.
            optional_cond_image_indices: Comma-separated keyframe indices in
                pixel frames; negative values count from the end.
            per_tile_seed_offsets: Comma-separated seed offsets, one per
                temporal chunk.

        Returns:
            One-element tuple holding a latent dict with the blended
            ``"samples"`` tensor.

        Raises:
            ValueError: For audio-visual (nested) latents, or when the
                temporal tiling cannot produce a chunk.
            AssertionError: When the guiding latents do not match the latents
                in frame count or by an integer spatial ratio.
        """
        samples = latents["samples"]
        if (
            isinstance(samples, comfy.nested_tensor.NestedTensor)
            and len(samples.tensors) == 2
        ):
            raise ValueError(
                "LoopingSampler currently does not support Audio Visual latents. please only use video latents."
            )
        batch, channels, frames, height, width = samples.shape
        time_scale_factor, width_scale_factor, height_scale_factor = (
            vae.downscale_index_formula
        )
        # From here on the temporal tiling is expressed in latent frames.
        temporal_tile_size = temporal_tile_size // time_scale_factor
        temporal_overlap = temporal_overlap // time_scale_factor
        first_seed = noise.seed

        per_tile_seed_offsets = self._parse_comma_separated_string(
            per_tile_seed_offsets, "0", int
        )

        num_pixel_frames = frames * time_scale_factor - 7
        keyframe_indices = self._parse_comma_separated_string(
            optional_cond_image_indices,
            "0",
            int,
            optional_total_size=num_pixel_frames,
        )
        keyframe_per_tile_indices = self._calculate_keyframe_per_tile_indices(
            keyframe_indices,
            temporal_tile_size * time_scale_factor,
            temporal_overlap * time_scale_factor,
            num_pixel_frames,
        )
        print(f"Keyframe per tile indices: {keyframe_per_tile_indices}")

        if optional_cond_images is not None:
            # Resize the images to the full pixel resolution of the video.
            optional_keyframes = (
                comfy.utils.common_upscale(
                    optional_cond_images.movedim(-1, 1),
                    width * width_scale_factor,
                    height * height_scale_factor,
                    "bilinear",
                    crop="center",
                )
                .movedim(1, -1)
                .clamp(0, 1)
            )
        else:
            optional_keyframes = None

        if optional_guiding_latents is not None:
            guide = optional_guiding_latents["samples"]
            assert (
                samples.shape[2] == guide.shape[2]
            ), "The number of frames in the latents and optional_guiding_latents must be the same"
            assert (
                samples.shape[3] % guide.shape[3] == 0
                and samples.shape[4] % guide.shape[4] == 0
            ), "The ratio of the height and width of the latents and optional_guiding_latents must be an integer"
            grid_size_h = samples.shape[3] // guide.shape[3]
            grid_size_w = samples.shape[4] // guide.shape[4]
            optional_guiding_latents = LTXVDilateLatent().dilate_latent(
                optional_guiding_latents, grid_size_w, grid_size_h
            )[0]

        # Tile size such that the tiles, overlapping by spatial_overlap, span
        # the full extent (the last tile absorbs any remainder).
        base_tile_height = (
            height + (vertical_tiles - 1) * spatial_overlap
        ) // vertical_tiles
        base_tile_width = (
            width + (horizontal_tiles - 1) * spatial_overlap
        ) // horizontal_tiles

        # These two configurations do not depend on the spatial tile (spatial
        # cropping leaves the frame dimension untouched), so build them once.
        sampling_config = SamplingConfig(
            temporal_tile_size=temporal_tile_size,
            temporal_overlap=temporal_overlap,
            temporal_overlap_cond_strength=temporal_overlap_cond_strength,
            cond_image_strength=cond_image_strength,
            guiding_strength=guiding_strength,
            adain_factor=adain_factor,
            optional_negative_index=(
                -1 * optional_negative_index_latents["samples"].shape[2]
                if optional_negative_index_latents is not None
                else -1
            ),
            optional_negative_index_strength=optional_negative_index_strength,
            optional_positive_conditionings=optional_positive_conditionings,
            time_scale_factor=time_scale_factor,
            width_scale_factor=width_scale_factor,
            height_scale_factor=height_scale_factor,
            per_tile_seed_offsets=per_tile_seed_offsets,
            guiding_start_step=guiding_start_step,
            guiding_end_step=guiding_end_step,
        )
        model_config = ModelConfig(
            model=model,
            vae=vae,
            noise=noise,
            sampler=sampler,
            sigmas=sigmas,
            guider=guider,
        )

        final_output = None
        weights = None

        try:
            for v in range(vertical_tiles):
                for h in range(horizontal_tiles):
                    h_start = h * (base_tile_width - spatial_overlap)
                    v_start = v * (base_tile_height - spatial_overlap)

                    # The last tile of a row/column always reaches the border.
                    h_end = (
                        min(h_start + base_tile_width, width)
                        if h < horizontal_tiles - 1
                        else width
                    )
                    v_end = (
                        min(v_start + base_tile_height, height)
                        if v < vertical_tiles - 1
                        else height
                    )

                    tile_height = v_end - v_start
                    tile_width = h_end - h_start

                    print(f"Processing spatial tile at row {v}, col {h}:")
                    print(f"  Position: ({v_start}:{v_end}, {h_start}:{h_end})")
                    print(f"  Size: {tile_height}x{tile_width}")

                    (
                        tile_latents,
                        tile_guiding_latents,
                        tile_negative_index_latents,
                        tile_keyframes,
                        tile_normalizing_latents,
                    ) = self._extract_spatial_tile(
                        latents,
                        optional_guiding_latents,
                        optional_negative_index_latents,
                        optional_normalizing_latents,
                        optional_keyframes,
                        v_start,
                        v_end,
                        h_start,
                        h_end,
                        height_scale_factor,
                        width_scale_factor,
                    )

                    tile_config = TileConfig(
                        tile_latents=tile_latents,
                        tile_guiding_latents=tile_guiding_latents,
                        tile_negative_index_latents=tile_negative_index_latents,
                        tile_keyframes=tile_keyframes,
                        keyframe_per_tile_indices=keyframe_per_tile_indices,
                        tile_normalizing_latents=tile_normalizing_latents,
                        tile_height=tile_height,
                        tile_width=tile_width,
                        v=v,
                        h=h,
                        vertical_tiles=vertical_tiles,
                        horizontal_tiles=horizontal_tiles,
                        first_seed=first_seed,
                    )

                    tile_out_latents = self._process_temporal_chunks(
                        tile_config,
                        sampling_config,
                        model_config,
                    )

                    # Allocate the accumulators once the output frame count
                    # is known from the first finished tile.
                    if final_output is None:
                        out_temporal = tile_out_latents["samples"].shape[2]
                        final_output = torch.zeros(
                            batch,
                            channels,
                            out_temporal,
                            height,
                            width,
                            device=samples.device,
                            dtype=samples.dtype,
                        )
                        weights = torch.zeros_like(final_output)

                    tile_samples = tile_out_latents["samples"].to(final_output.device)
                    tile_weights = self._create_spatial_weights(
                        tile_samples.shape,
                        v,
                        h,
                        horizontal_tiles,
                        vertical_tiles,
                        spatial_overlap,
                        final_output.device,
                        final_output.dtype,
                    )

                    # Accumulate the weighted tile and its weights.
                    final_output[:, :, :, v_start:v_end, h_start:h_end] += (
                        tile_samples * tile_weights
                    )
                    weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
        finally:
            # The per-chunk seeds were written onto the caller's noise object;
            # restore the original seed even when sampling fails midway.
            noise.seed = first_seed

        # Normalize by the accumulated weights; the epsilon avoids a division
        # by zero where every contributing weight is zero.
        final_output = final_output / (weights + 1e-8)
        out_latents = {"samples": final_output}

        return (out_latents,)
|
|
|
|
@comfy_node(
    name="MultiPromptProvider",
)
class MultiPromptProvider:
    """ComfyUI node that encodes several prompts into a list of conditionings.

    The prompt text is split on the ``|`` character and every piece is
    encoded separately with the supplied CLIP model.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the ComfyUI input specification of the node."""
        return {
            "required": {
                "prompts": (
                    "STRING",
                    {
                        "multiline": True,
                        "dynamicPrompts": True,
                        # The text is split on "|" (see get_prompt_list), not
                        # on line breaks as the tooltip previously stated.
                        "tooltip": "Prompts to encode, separated by '|'. Each prompt will be encoded separately. Each prompt will be used in one temporal_tile in LTXVLoopingSampler.",
                    },
                ),
                "clip": ("CLIP", {"tooltip": "CLIP model to encode the prompts."}),
            },
        }

    RETURN_TYPES = ("CONDITIONING",)
    RETURN_NAMES = ("conditionings",)

    FUNCTION = "get_prompt_list"
    CATEGORY = "prompt"

    def get_prompt_list(self, prompts, clip):
        """Encode every ``|``-separated prompt.

        Args:
            prompts: Prompt text; the pieces between ``|`` characters are
                stripped of surrounding whitespace and encoded one by one.
                Empty pieces are encoded as empty prompts.
            clip: Object providing ``tokenize`` and
                ``encode_from_tokens_scheduled``.

        Returns:
            One-element tuple holding the list of encoded prompts, in input
            order.
        """
        prompt_list = [prompt.strip() for prompt in prompts.split("|")]
        encoded_prompt_list = [
            clip.encode_from_tokens_scheduled(clip.tokenize(prompt))
            for prompt in prompt_list
        ]
        return (encoded_prompt_list,)
|
|