Wan_Backup / custom_nodes /ComfyUI-LTXVideo /looping_sampler.py
Eji-Sensei14's picture
Upload folder using huggingface_hub
c6535db verified
import copy
from dataclasses import dataclass
from typing import Optional

import comfy
import torch

from .easy_samplers import LTXVBaseSampler, LTXVExtendSampler, LTXVInContextSampler
from .latents import LTXVDilateLatent, LTXVSelectLatents
from .nodes_registry import comfy_node
@dataclass
class TileConfig:
    """Inputs and bookkeeping for one spatial tile of the looping sampler.

    The latent-like fields hold the spatial crop of the corresponding input.
    Every input that is optional for the sampler is ``None`` here when it was
    not supplied, which is why those fields are annotated as ``Optional``.
    """

    # Latent dict ("samples" and, when present, "noise_mask") cropped to the tile.
    tile_latents: dict
    # Cropped guiding latents, or None when no guiding latents were given.
    tile_guiding_latents: Optional[dict]
    # Cropped negative-index latents, or None when not given.
    tile_negative_index_latents: Optional[dict]
    # Keyframe images cropped to the tile (first axis indexes the keyframe),
    # or None when no conditioning images were given.
    tile_keyframes: Optional[torch.Tensor]
    # One (temporal_tile_index, in_tile_index) pair per keyframe.
    keyframe_per_tile_indices: list
    # Cropped normalizing latents, or None when not given.
    tile_normalizing_latents: Optional[dict]
    # Tile extent in latent units.
    tile_height: int
    tile_width: int
    # Row / column of this tile in the spatial grid.
    v: int
    h: int
    # Grid dimensions.
    vertical_tiles: int
    horizontal_tiles: int
    # Seed of the noise object before any per-chunk offset is applied.
    first_seed: int
@dataclass
class SamplingConfig:
    """Sampling parameters shared by every temporal chunk of a run.

    Temporal sizes are stored in latent frames (the caller divides the
    pixel-frame inputs by the VAE time scale factor before building this).
    """

    # Temporal window length and overlap between consecutive windows.
    temporal_tile_size: int
    temporal_overlap: int
    # Conditioning strengths forwarded to the samplers.
    temporal_overlap_cond_strength: float
    cond_image_strength: float
    guiding_strength: float
    adain_factor: float
    # Negative index (and its strength) used with the negative-index latents.
    optional_negative_index: int
    optional_negative_index_strength: float
    # Per-chunk positive conditionings, or None to keep the guider's own prompt.
    optional_positive_conditionings: Optional[list]
    # VAE downscale factors (latent -> pixel).
    time_scale_factor: int
    width_scale_factor: int
    height_scale_factor: int
    # Seed offsets, one per temporal chunk; the last value is reused when the
    # list is shorter than the number of chunks.
    per_tile_seed_offsets: list
    # Step range in which guiding latents are applied.
    guiding_start_step: int
    guiding_end_step: int
@dataclass
class ModelConfig:
    """Bundle of the model-side objects handed to every sampler call.

    The fields are typed as plain ``object``: this module only forwards them
    to the samplers, apart from writing ``noise.seed`` before each chunk.
    """

    model: object
    vae: object
    # The ``seed`` attribute of this object is overwritten for every chunk.
    noise: object
    sampler: object
    sigmas: object
    # Guider used as-is, or shallow-copied when per-chunk prompts are supplied.
    guider: object
# Node that builds a long video latent: the input latents are cut into spatial
# tiles, each tile is sampled as a sequence of overlapping temporal chunks, and
# the tiles are blended back together with linear overlap weights.
# Registered under the name "LTXVLoopingSampler" via the `comfy_node` decorator
# imported from .nodes_registry (what the decorator does is not visible here).
@comfy_node(
name="LTXVLoopingSampler",
)
class LTXVLoopingSampler:
@classmethod
def INPUT_TYPES(s):
    """Describe the node inputs as ``{"required": {...}, "optional": {...}}``.

    Each entry maps an input name to ``(type_name, options)``. The names
    match the keyword parameters of ``sample``.

    NOTE(review): ``adain_factor`` is declared optional here, while ``sample``
    has no default for it.
    """

    def plain(type_name, tooltip):
        # Input that carries nothing but a tooltip.
        return (type_name, {"tooltip": tooltip})

    def unit_float(default, tooltip):
        # Strength-like inputs all share the [0, 1] range and a 0.01 step.
        return (
            "FLOAT",
            {
                "default": default,
                "min": 0.0,
                "max": 1.0,
                "step": 0.01,
                "tooltip": tooltip,
            },
        )

    def bounded_int(default, minimum, maximum, tooltip, step=None):
        # Key order mirrors the hand-written literals: default, min, max,
        # optional step, tooltip.
        options = {"default": default, "min": minimum, "max": maximum}
        if step is not None:
            options["step"] = step
        options["tooltip"] = tooltip
        return ("INT", options)

    required = {
        "model": plain("MODEL", "The diffusion model to use."),
        "vae": plain("VAE", "The VAE to use."),
        "noise": plain("NOISE", "The noise to use."),
        "sampler": plain("SAMPLER", "The sampler to use."),
        "sigmas": plain("SIGMAS", "The sigmas to use."),
        "guider": plain(
            "GUIDER", "The guider to use, must be a STGGuiderAdvanced."
        ),
        "latents": plain(
            "LATENT",
            "The latents to use for creating the long video, they can be guiding latents or empty latents when no guidance is used.",
        ),
        "temporal_tile_size": bounded_int(
            80,
            24,
            1000,
            "The size of the temporal tile to use for the sampling, in pixel frames, in addition to the overlapping region.",
            step=8,
        ),
        "temporal_overlap": bounded_int(
            24,
            16,
            80,
            "The overlap between the temporal tiles, in pixel frames.",
            step=8,
        ),
        "guiding_strength": unit_float(
            1.0,
            "The strength of the conditioning on guiding latents, when optional_guiding_latents are provided.",
        ),
        "temporal_overlap_cond_strength": unit_float(
            0.5,
            "The strength of the conditioning on the latents from the previous temporal tile.",
        ),
        "cond_image_strength": unit_float(
            1.0,
            "The strength of the conditioning on the optional_cond_images.",
        ),
        "horizontal_tiles": bounded_int(
            1, 1, 6, "Number of horizontal spatial tiles."
        ),
        "vertical_tiles": bounded_int(
            1, 1, 6, "Number of vertical spatial tiles."
        ),
        "spatial_overlap": bounded_int(
            1, 1, 8, "Overlap between spatial tiles."
        ),
    }

    optional = {
        "optional_cond_images": plain(
            "IMAGE",
            "The image to use for conditioning the first frame in the video (i2v setup). If not provided, the first frame will be unconditioned (t2v setup). The image will be resized to the size of the first frame.",
        ),
        "optional_guiding_latents": plain(
            "LATENT",
            "The latents to use for guiding the sampling, typically with an IC-LoRA.",
        ),
        "adain_factor": unit_float(
            0.0,
            "The strength of the AdaIn operation used to fix the statistics of each new generated temporal tile, to prevent accumulated oversaturation.",
        ),
        "optional_positive_conditionings": plain(
            "CONDITIONING",
            "Optional way to provide changing positive prompts, one per temporal tile, using the MultiPromptProvider node.",
        ),
        "optional_negative_index_latents": plain(
            "LATENT",
            "Special optional latents to condition on a negative index before each new temporal tile as a way to provide long term context during video generation.",
        ),
        "guiding_start_step": bounded_int(
            0,
            0,
            1000,
            "The step at which the guiding latents start to be used.",
        ),
        "guiding_end_step": bounded_int(
            1000,
            0,
            1000,
            "The step at which the guiding latents stop to be used.",
        ),
        "optional_cond_image_indices": (
            "STRING",
            {
                "default": "0",
                # Only negative values are wrapped from the end of the video
                # when the string is parsed; the old text claimed all were.
                "tooltip": "The indices of the keyframes to use for the sampling, separated by commas. Non-negative indices are counted from the start of the video; negative indices are counted from the end.",
            },
        ),
        "optional_normalizing_latents": plain(
            "LATENT",
            "The latents to use for normalizing the output latents, they will be used to normalize the output latents to the same statistics as the input latents.",
        ),
    }

    return {"required": required, "optional": optional}
# Node metadata (presumably read by the ComfyUI node loader - TODO confirm).
# FUNCTION names the `sample` method below, which returns a one-element tuple
# holding the latent dict described by RETURN_TYPES / RETURN_NAMES.
RETURN_TYPES = ("LATENT",)
RETURN_NAMES = ("denoised_output",)
FUNCTION = "sample"
CATEGORY = "sampling"
def _extract_latent_spatial_tile(self, latent_dict, v_start, v_end, h_start, h_end):
"""Extract spatial tile from a latent dictionary."""
if latent_dict is None:
return None
tile_samples = latent_dict["samples"][:, :, :, v_start:v_end, h_start:h_end]
if "noise_mask" in latent_dict and latent_dict["noise_mask"] is not None:
tile_masks = latent_dict["noise_mask"][
:, :, :, v_start:v_end, h_start:h_end
]
return {"samples": tile_samples, "noise_mask": tile_masks}
else:
return {"samples": tile_samples}
def _extract_spatial_tile(
self,
latents,
optional_guiding_latents,
optional_negative_index_latents,
optional_normalizing_latents,
optional_keyframes,
v_start,
v_end,
h_start,
h_end,
height_scale_factor,
width_scale_factor,
):
"""Extract spatial tiles from all inputs for a given spatial region."""
# Extract spatial tile from latents
tile_latents = self._extract_latent_spatial_tile(
latents, v_start, v_end, h_start, h_end
)
# Extract spatial tile from guiding latents if provided
tile_guiding_latents = self._extract_latent_spatial_tile(
optional_guiding_latents, v_start, v_end, h_start, h_end
)
# Extract spatial tile from negative index latents if provided
tile_negative_index_latents = self._extract_latent_spatial_tile(
optional_negative_index_latents, v_start, v_end, h_start, h_end
)
# Extract spatial tile from normalizing latents if provided
tile_normalizing_latents = self._extract_latent_spatial_tile(
optional_normalizing_latents, v_start, v_end, h_start, h_end
)
if optional_keyframes is not None:
# Scale coordinates for image
img_h_start = v_start * height_scale_factor
img_h_end = v_end * height_scale_factor
img_w_start = h_start * width_scale_factor
img_w_end = h_end * width_scale_factor
tile_keyframes = optional_keyframes[
:, img_h_start:img_h_end, img_w_start:img_w_end, :
]
else:
tile_keyframes = None
return (
tile_latents,
tile_guiding_latents,
tile_negative_index_latents,
tile_keyframes,
tile_normalizing_latents,
)
def _process_temporal_chunks(
self,
tile_config: TileConfig,
sampling_config: SamplingConfig,
model_config: ModelConfig,
):
"""Process all temporal chunks for a single spatial tile."""
chunk_index = 0
tile_out_latents = None
first_tile_out_latents = None
for i_temporal_tile, (start_index, end_index) in enumerate(
zip(
range(
0,
tile_config.tile_latents["samples"].shape[2]
+ sampling_config.temporal_tile_size
- sampling_config.temporal_overlap,
sampling_config.temporal_tile_size
- sampling_config.temporal_overlap,
),
range(
sampling_config.temporal_tile_size,
tile_config.tile_latents["samples"].shape[2]
+ sampling_config.temporal_tile_size
- sampling_config.temporal_overlap,
sampling_config.temporal_tile_size
- sampling_config.temporal_overlap,
),
)
):
if tile_config.tile_guiding_latents is not None:
guiding_latent_chunk = LTXVSelectLatents().select_latents(
tile_config.tile_guiding_latents,
start_index,
min(
end_index - 1,
tile_config.tile_guiding_latents["samples"].shape[2] - 1,
),
)[0]
else:
guiding_latent_chunk = None
latent_chunk = LTXVSelectLatents().select_latents(
tile_config.tile_latents,
start_index,
min(end_index - 1, tile_config.tile_latents["samples"].shape[2] - 1),
)[0]
print(
"Processing temporal chunk at index",
start_index,
"to",
min(end_index - 1, tile_config.tile_latents["samples"].shape[2] - 1),
)
seed_offset = self._get_per_tile_value(
sampling_config.per_tile_seed_offsets, i_temporal_tile
)
if tile_config.tile_normalizing_latents is not None:
normalizing_latent_chunk = LTXVSelectLatents().select_latents(
tile_config.tile_normalizing_latents,
start_index,
min(
end_index - 1,
tile_config.tile_normalizing_latents["samples"].shape[2] - 1,
),
)[0]
normalize_per_frame = True
print(
"Normalizing latents provided, normalizing per frame and channel with factor",
sampling_config.adain_factor,
)
else:
normalizing_latent_chunk = first_tile_out_latents
normalize_per_frame = False
print(
"No normalizing latents provided, normalizing per channel using first chunk with factor",
sampling_config.adain_factor,
)
model_config.noise.seed = self._calculate_tile_seed(
tile_config.first_seed,
start_index,
tile_config.vertical_tiles,
tile_config.horizontal_tiles,
tile_config.v,
tile_config.h,
seed_offset,
)
# Handle optional positive conditionings
new_guider = self._prepare_guider_for_chunk(
model_config.guider,
sampling_config.optional_positive_conditionings,
chunk_index,
)
this_chunk_keyframe_indices = [
in_tile_index
for (tile_index, in_tile_index) in tile_config.keyframe_per_tile_indices
if tile_index == i_temporal_tile
]
if this_chunk_keyframe_indices and tile_config.tile_keyframes is not None:
this_chunk_keyframes = torch.cat(
[
tile_config.tile_keyframes[i_keyframe].unsqueeze(0)
for (i_keyframe, (tile_index, in_tile_index)) in enumerate(
tile_config.keyframe_per_tile_indices
)
if tile_index == i_temporal_tile
]
)
print(
f"Chunk {i_temporal_tile} keyframe indices: {this_chunk_keyframe_indices}"
)
else:
this_chunk_keyframes = None
print(f"Chunk {i_temporal_tile} has no keyframes")
this_chunk_keyframe_indices = ",".join(
[str(i) for i in this_chunk_keyframe_indices]
)
if start_index == 0:
if tile_config.tile_guiding_latents is not None:
tile_out_latents = LTXVInContextSampler().sample(
vae=model_config.vae,
guider=new_guider,
sampler=model_config.sampler,
sigmas=model_config.sigmas,
noise=model_config.noise,
guiding_latents=guiding_latent_chunk,
optional_cond_images=this_chunk_keyframes,
optional_cond_indices=this_chunk_keyframe_indices,
num_frames=-1,
optional_negative_index_latents=tile_config.tile_negative_index_latents,
optional_negative_index=sampling_config.optional_negative_index,
optional_negative_index_strength=sampling_config.optional_negative_index_strength,
optional_initialization_latents=latent_chunk,
cond_image_strength=sampling_config.cond_image_strength,
guiding_strength=sampling_config.guiding_strength,
guiding_start_step=sampling_config.guiding_start_step,
guiding_end_step=sampling_config.guiding_end_step,
)[0]
else:
tile_out_latents = LTXVBaseSampler().sample(
model=model_config.model,
vae=model_config.vae,
noise=model_config.noise,
sampler=model_config.sampler,
sigmas=model_config.sigmas,
guider=new_guider,
num_frames=(
min(
sampling_config.temporal_tile_size,
tile_config.tile_latents["samples"].shape[2],
)
- 1
)
* sampling_config.time_scale_factor
+ 1,
width=tile_config.tile_width
* sampling_config.width_scale_factor,
height=tile_config.tile_height
* sampling_config.height_scale_factor,
optional_cond_images=this_chunk_keyframes,
optional_cond_indices=this_chunk_keyframe_indices,
crop="center",
crf=30,
strength=sampling_config.cond_image_strength,
optional_negative_index_latents=tile_config.tile_negative_index_latents,
optional_negative_index=sampling_config.optional_negative_index,
optional_negative_index_strength=sampling_config.optional_negative_index_strength,
optional_initialization_latents=latent_chunk,
guiding_start_step=sampling_config.guiding_start_step,
guiding_end_step=sampling_config.guiding_end_step,
)[0]
first_tile_out_latents = copy.deepcopy(tile_out_latents)
else:
tile_out_latents = LTXVExtendSampler().sample(
model=model_config.model,
vae=model_config.vae,
sampler=model_config.sampler,
sigmas=model_config.sigmas,
noise=model_config.noise,
latents=tile_out_latents,
num_new_frames=(
latent_chunk["samples"].shape[2]
- sampling_config.temporal_overlap
)
* sampling_config.time_scale_factor,
frame_overlap=sampling_config.temporal_overlap
* sampling_config.time_scale_factor,
guider=new_guider,
strength=sampling_config.temporal_overlap_cond_strength,
guiding_strength=sampling_config.guiding_strength,
cond_image_strength=sampling_config.cond_image_strength,
optional_guiding_latents=guiding_latent_chunk,
optional_cond_images=this_chunk_keyframes,
optional_cond_indices=this_chunk_keyframe_indices,
optional_reference_latents=normalizing_latent_chunk,
normalize_per_frame=normalize_per_frame,
adain_factor=sampling_config.adain_factor,
optional_negative_index_latents=tile_config.tile_negative_index_latents,
optional_negative_index=sampling_config.optional_negative_index,
optional_negative_index_strength=sampling_config.optional_negative_index_strength,
optional_initialization_latents=latent_chunk,
guiding_start_step=sampling_config.guiding_start_step,
guiding_end_step=sampling_config.guiding_end_step,
)[0]
chunk_index += 1
return tile_out_latents
def _create_spatial_weights(
self,
tile_shape,
v,
h,
horizontal_tiles,
vertical_tiles,
spatial_overlap,
device,
dtype,
):
"""Create blending weights for spatial tiles."""
tile_weights = torch.ones(tile_shape, device=device, dtype=dtype)
# Apply horizontal blending weights
if h > 0: # Left overlap
h_blend = torch.linspace(0, 1, spatial_overlap, device=device, dtype=dtype)
tile_weights[:, :, :, :, :spatial_overlap] *= h_blend.view(1, 1, 1, 1, -1)
if h < horizontal_tiles - 1: # Right overlap
h_blend = torch.linspace(1, 0, spatial_overlap, device=device, dtype=dtype)
tile_weights[:, :, :, :, -spatial_overlap:] *= h_blend.view(1, 1, 1, 1, -1)
# Apply vertical blending weights
if v > 0: # Top overlap
v_blend = torch.linspace(0, 1, spatial_overlap, device=device, dtype=dtype)
tile_weights[:, :, :, :spatial_overlap, :] *= v_blend.view(1, 1, 1, -1, 1)
if v < vertical_tiles - 1: # Bottom overlap
v_blend = torch.linspace(1, 0, spatial_overlap, device=device, dtype=dtype)
tile_weights[:, :, :, -spatial_overlap:, :] *= v_blend.view(1, 1, 1, -1, 1)
return tile_weights
def _calculate_tile_seed(
self,
first_seed,
start_index,
vertical_tiles,
horizontal_tiles,
v,
h,
seed_offset,
):
"""Calculate the seed value for a specific temporal and spatial tile."""
return (
first_seed
+ start_index * (vertical_tiles * horizontal_tiles)
+ v * horizontal_tiles
+ h
+ seed_offset
)
def _get_per_tile_value(self, value_list, tile_index):
"""Get a value from a per-tile configuration list, using the last value if the list is shorter."""
return value_list[min(tile_index, len(value_list) - 1)]
def _parse_comma_separated_string(
self, config_string, default_value, converter_func, optional_total_size=None
):
"""Parse a comma-separated configuration string into a list with type conversion."""
if config_string == "":
config_string = default_value
values = [converter_func(item.strip()) for item in config_string.split(",")]
def handle_negative_index(value):
if value < 0:
return value + optional_total_size
return value
if optional_total_size is not None:
values = [handle_negative_index(value) for value in values]
return values
def _prepare_guider_for_chunk(
self, guider, optional_positive_conditionings, chunk_index
):
"""Prepare the guider for a specific chunk, handling optional positive conditionings."""
if optional_positive_conditionings is not None:
new_guider = copy.copy(guider)
positive, negative = guider.raw_conds
# Use the conditioning at chunk_index, or the last one if we've run out
conditioning_index = min(
chunk_index, len(optional_positive_conditionings) - 1
)
new_guider.set_conds(
optional_positive_conditionings[conditioning_index],
negative,
)
new_guider.raw_conds = (
optional_positive_conditionings[conditioning_index],
negative,
)
return new_guider
else:
return guider
def _calculate_keyframe_per_tile_indices(
self, keyframe_indices, temporal_tile_size, temporal_overlap, num_frames
):
"""
Calculate which temporal tile each keyframe falls into.
Returns a list of tuples (temporal_tile_index, in_tile_index) for each keyframe.
- temporal_tile_index: which temporal tile the keyframe falls in
- in_tile_index: index within that specific temporal tile
First tile: frames [0, temporal_tile_size - 8] (size = temporal_tile_size - 8 + 1)
Subsequent tiles follow the pattern from _process_temporal_chunks:
- Tile n starts at: n * (temporal_tile_size - temporal_overlap)
- Tile n ends at: temporal_tile_size + n * (temporal_tile_size - temporal_overlap) - 1
For subsequent tiles, keyframes with in_tile_index < temporal_overlap
don't count as falling in that tile.
"""
result = []
for keyframe_index in keyframe_indices:
if keyframe_index >= num_frames:
print(
f"Keyframe index {keyframe_index} is greater than num_frames {num_frames}, skipping"
)
continue
# First tile (tile 0): covers frames [0, temporal_tile_size - 8]
if keyframe_index < temporal_tile_size - 7:
result.append((0, keyframe_index))
continue
# Find which subsequent tile this keyframe could fall into
# Tile n starts at: n * (temporal_tile_size - temporal_overlap) - 7
# Tile n ends at: temporal_tile_size + n * (temporal_tile_size - temporal_overlap) - 1 - 7
tile_step = temporal_tile_size - temporal_overlap
tile_index = 1
while True:
tile_start = tile_index * tile_step - 7
tile_end = temporal_tile_size + tile_index * tile_step - 1 - 7
print(
f"Tile {tile_index} starts at {tile_start} and ends at {tile_end}"
)
# Check if keyframe falls within this tile's range
if keyframe_index <= tile_end:
in_tile_index = (
keyframe_index - tile_start - 7
) # this extra -7 is needed because the first latent is re-interpreted inside the temporal tile as 1 pixel frame
# For tiles > 0, if in_tile_index < temporal_overlap,
# the keyframe doesn't "fall" in this tile - assign to previous tile
if in_tile_index < temporal_overlap:
tile_index -= 1
if tile_index == 0:
# Previous tile is the first tile
in_tile_index = keyframe_index
else:
# Previous tile starts at (tile_index) * tile_step
prev_start = tile_start - tile_step
in_tile_index = (
keyframe_index - prev_start - 7
) # same reason as above
result.append((tile_index, in_tile_index))
break
tile_index += 1
return result
def sample(
    self,
    model,
    vae,
    noise,
    sampler,
    sigmas,
    guider,
    latents,
    guiding_strength,
    adain_factor,
    temporal_tile_size,
    temporal_overlap,
    temporal_overlap_cond_strength,
    horizontal_tiles,
    vertical_tiles,
    spatial_overlap,
    optional_cond_images=None,
    cond_image_strength=1.0,
    optional_guiding_latents=None,
    optional_negative_index_latents=None,
    optional_negative_index_strength=1.0,  # hidden interface
    optional_positive_conditionings=None,
    guiding_start_step=0,
    guiding_end_step=1000,
    optional_cond_image_indices="0",
    optional_normalizing_latents=None,
    per_tile_seed_offsets="0",  # hidden interface
):
    """Sample a long video latent tile by tile.

    The latents are split into ``vertical_tiles x horizontal_tiles`` spatial
    tiles that overlap by ``spatial_overlap`` latent pixels. Each tile is
    sampled chunk by chunk along time via ``_process_temporal_chunks``, and
    the results are blended with the weights from
    ``_create_spatial_weights`` and normalised by the accumulated weight.

    ``temporal_tile_size`` and ``temporal_overlap`` are given in pixel
    frames and converted to latent frames with the VAE time scale factor.
    ``noise.seed`` is changed while sampling and restored before returning,
    including when sampling raises.

    Returns:
        A one-element tuple holding ``{"samples": blended_latents}``.

    Raises:
        ValueError: if ``temporal_overlap`` is not smaller than
            ``temporal_tile_size`` after conversion to latent frames, or if
            the latents are a two-tensor (audio + video) nested tensor.
        AssertionError: if the guiding latents do not match the latents in
            frame count or are not an integer fraction of their size.
    """
    samples = latents["samples"]

    time_scale_factor, width_scale_factor, height_scale_factor = (
        vae.downscale_index_formula
    )
    # Pixel frames -> latent frames.
    temporal_tile_size = temporal_tile_size // time_scale_factor
    temporal_overlap = temporal_overlap // time_scale_factor
    # Fail early with a clear message: with a non-positive stride the
    # temporal window can never advance.
    if temporal_overlap >= temporal_tile_size:
        raise ValueError(
            "temporal_overlap must be smaller than temporal_tile_size, got "
            f"{temporal_overlap} and {temporal_tile_size} latent frames "
            f"(time scale factor {time_scale_factor})."
        )

    if (
        isinstance(samples, comfy.nested_tensor.NestedTensor)
        and len(samples.tensors) == 2
    ):
        raise ValueError(
            "LoopingSampler currently does not support Audio Visual latents. please only use video latents."
        )

    batch, channels, frames, height, width = samples.shape
    first_seed = noise.seed

    # Upper bound for keyframe indices, in pixel frames.
    # NOTE(review): the -7 presumably assumes a time scale factor of 8
    # ((frames - 1) * 8 + 1) - confirm.
    keyframe_index_limit = frames * time_scale_factor - 7

    per_tile_seed_offsets = self._parse_comma_separated_string(
        per_tile_seed_offsets, "0", int
    )
    keyframe_indices = self._parse_comma_separated_string(
        optional_cond_image_indices,
        "0",
        int,
        optional_total_size=keyframe_index_limit,
    )
    keyframe_per_tile_indices = self._calculate_keyframe_per_tile_indices(
        keyframe_indices,
        temporal_tile_size * time_scale_factor,
        temporal_overlap * time_scale_factor,
        keyframe_index_limit,
    )
    print(f"Keyframe per tile indices: {keyframe_per_tile_indices}")

    if optional_cond_images is not None:
        # Resize the conditioning images to the full pixel resolution so
        # they can be cropped per spatial tile.
        optional_keyframes = (
            comfy.utils.common_upscale(
                optional_cond_images.movedim(-1, 1),
                width * width_scale_factor,
                height * height_scale_factor,
                "bilinear",
                crop="center",
            )
            .movedim(1, -1)
            .clamp(0, 1)
        )
    else:
        optional_keyframes = None

    if optional_guiding_latents is not None:
        guide = optional_guiding_latents["samples"]
        assert (
            samples.shape[2] == guide.shape[2]
        ), "The number of frames in the latents and optional_guiding_latents must be the same"
        assert (
            samples.shape[3] % guide.shape[3] == 0
            and samples.shape[4] % guide.shape[4] == 0
        ), "The ratio of the height and width of the latents and optional_guiding_latents must be an integer"
        grid_size_h = samples.shape[3] // guide.shape[3]
        grid_size_w = samples.shape[4] // guide.shape[4]
        optional_guiding_latents = LTXVDilateLatent().dilate_latent(
            optional_guiding_latents, grid_size_w, grid_size_h
        )[0]

    # Tile sizes including the overlap shared with neighbouring tiles.
    base_tile_height = (
        height + (vertical_tiles - 1) * spatial_overlap
    ) // vertical_tiles
    base_tile_width = (
        width + (horizontal_tiles - 1) * spatial_overlap
    ) // horizontal_tiles

    # These do not depend on the spatial tile (cropping leaves the frame
    # axis untouched), so they are built once.
    sampling_config = SamplingConfig(
        temporal_tile_size=temporal_tile_size,
        temporal_overlap=temporal_overlap,
        temporal_overlap_cond_strength=temporal_overlap_cond_strength,
        cond_image_strength=cond_image_strength,
        guiding_strength=guiding_strength,
        adain_factor=adain_factor,
        optional_negative_index=(
            -1 * optional_negative_index_latents["samples"].shape[2]
            if optional_negative_index_latents is not None
            else -1
        ),
        optional_negative_index_strength=optional_negative_index_strength,
        optional_positive_conditionings=optional_positive_conditionings,
        time_scale_factor=time_scale_factor,
        width_scale_factor=width_scale_factor,
        height_scale_factor=height_scale_factor,
        per_tile_seed_offsets=per_tile_seed_offsets,
        guiding_start_step=guiding_start_step,
        guiding_end_step=guiding_end_step,
    )
    model_config = ModelConfig(
        model=model,
        vae=vae,
        noise=noise,
        sampler=sampler,
        sigmas=sigmas,
        guider=guider,
    )

    # Allocated after the first tile, once the output frame count is known.
    final_output = None
    weights = None

    try:
        for v in range(vertical_tiles):
            for h in range(horizontal_tiles):
                h_start = h * (base_tile_width - spatial_overlap)
                v_start = v * (base_tile_height - spatial_overlap)

                # The last tile of a row / column absorbs the remainder.
                h_end = (
                    min(h_start + base_tile_width, width)
                    if h < horizontal_tiles - 1
                    else width
                )
                v_end = (
                    min(v_start + base_tile_height, height)
                    if v < vertical_tiles - 1
                    else height
                )

                tile_height = v_end - v_start
                tile_width = h_end - h_start

                print(f"Processing spatial tile at row {v}, col {h}:")
                print(f"  Position: ({v_start}:{v_end}, {h_start}:{h_end})")
                print(f"  Size: {tile_height}x{tile_width}")

                (
                    tile_latents,
                    tile_guiding_latents,
                    tile_negative_index_latents,
                    tile_keyframes,
                    tile_normalizing_latents,
                ) = self._extract_spatial_tile(
                    latents,
                    optional_guiding_latents,
                    optional_negative_index_latents,
                    optional_normalizing_latents,
                    optional_keyframes,
                    v_start,
                    v_end,
                    h_start,
                    h_end,
                    height_scale_factor,
                    width_scale_factor,
                )

                tile_config = TileConfig(
                    tile_latents=tile_latents,
                    tile_guiding_latents=tile_guiding_latents,
                    tile_negative_index_latents=tile_negative_index_latents,
                    tile_keyframes=tile_keyframes,
                    keyframe_per_tile_indices=keyframe_per_tile_indices,
                    tile_normalizing_latents=tile_normalizing_latents,
                    tile_height=tile_height,
                    tile_width=tile_width,
                    v=v,
                    h=h,
                    vertical_tiles=vertical_tiles,
                    horizontal_tiles=horizontal_tiles,
                    first_seed=first_seed,
                )

                tile_out_latents = self._process_temporal_chunks(
                    tile_config,
                    sampling_config,
                    model_config,
                )

                if final_output is None:
                    out_temporal = tile_out_latents["samples"].shape[2]
                    final_output = torch.zeros(
                        batch,
                        channels,
                        out_temporal,
                        height,
                        width,
                        device=samples.device,
                        dtype=samples.dtype,
                    )
                    weights = torch.zeros_like(final_output)

                tile_samples = tile_out_latents["samples"].to(final_output.device)
                tile_weights = self._create_spatial_weights(
                    tile_samples.shape,
                    v,
                    h,
                    horizontal_tiles,
                    vertical_tiles,
                    spatial_overlap,
                    final_output.device,
                    final_output.dtype,
                )

                # Accumulate the weighted tile and the weights themselves.
                final_output[:, :, :, v_start:v_end, h_start:h_end] += (
                    tile_samples * tile_weights
                )
                weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
    finally:
        # The per-chunk seeds were written into the shared noise object;
        # put the caller's seed back even if sampling failed.
        noise.seed = first_seed

    # Normalise by the accumulated weights (epsilon avoids division by zero).
    final_output = final_output / (weights + 1e-8)
    out_latents = {"samples": final_output}
    return (out_latents,)
# Node that encodes several prompts with a CLIP model and outputs the list of
# encodings; LTXVLoopingSampler accepts such a list as
# `optional_positive_conditionings` (one entry per temporal tile).
# Registered under the name "MultiPromptProvider" via `comfy_node`.
@comfy_node(
name="MultiPromptProvider",
)
class MultiPromptProvider:
@classmethod
def INPUT_TYPES(s):
    """Describe the node inputs: the prompt text and the CLIP model.

    The names match the parameters of ``get_prompt_list``.
    """
    return {
        "required": {
            "prompts": (
                "STRING",
                {
                    "multiline": True,
                    "dynamicPrompts": True,
                    # get_prompt_list splits on "|", not on line breaks,
                    # so the tooltip names the real separator.
                    "tooltip": "Prompts to encode, separated by '|'. Each prompt will be encoded separately. Each prompt will be used in one temporal_tile in LTXVLoopingSampler.",
                },
            ),
            "clip": ("CLIP", {"tooltip": "CLIP model to encode the prompts."}),
        },
    }
# Node metadata (presumably read by the ComfyUI node loader - TODO confirm).
# FUNCTION names the `get_prompt_list` method below, which returns a
# one-element tuple holding the list of encoded prompts.
RETURN_TYPES = ("CONDITIONING",)
RETURN_NAMES = ("conditionings",)
FUNCTION = "get_prompt_list"
CATEGORY = "prompt"
def get_prompt_list(self, prompts, clip):
    """Encode every ``|``-separated prompt on its own.

    Each segment of ``prompts`` is stripped of surrounding whitespace,
    tokenized with ``clip.tokenize`` and encoded with
    ``clip.encode_from_tokens_scheduled``. Empty segments are encoded too.

    Returns:
        A one-element tuple holding the list of encodings, in prompt order.
    """
    encoded_prompt_list = []
    for raw_prompt in prompts.split("|"):
        tokens = clip.tokenize(raw_prompt.strip())
        encoded_prompt_list.append(clip.encode_from_tokens_scheduled(tokens))
    return (encoded_prompt_list,)