|
|
import random |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
import torch |
|
|
import torch.distributed.checkpoint.stateful |
|
|
from diffusers.video_processor import VideoProcessor |
|
|
|
|
|
import finetrainers.functional as FF |
|
|
from finetrainers.logging import get_logger |
|
|
from finetrainers.processors import CannyProcessor, CopyProcessor |
|
|
|
|
|
from .config import ControlType, FrameConditioningType |
|
|
|
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
class IterableControlDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): |
|
|
def __init__( |
|
|
self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.dataset = dataset |
|
|
self.control_type = control_type |
|
|
|
|
|
self.control_processors = [] |
|
|
if control_type == ControlType.CANNY: |
|
|
self.control_processors.append( |
|
|
CannyProcessor( |
|
|
output_names=["control_output"], input_names={"image": "input", "video": "input"}, device=device |
|
|
) |
|
|
) |
|
|
elif control_type == ControlType.NONE: |
|
|
self.control_processors.append( |
|
|
CopyProcessor(output_names=["control_output"], input_names={"image": "input", "video": "input"}) |
|
|
) |
|
|
|
|
|
logger.info("Initialized IterableControlDataset") |
|
|
|
|
|
def __iter__(self): |
|
|
logger.info("Starting IterableControlDataset") |
|
|
for data in iter(self.dataset): |
|
|
control_augmented_data = self._run_control_processors(data) |
|
|
yield control_augmented_data |
|
|
|
|
|
def load_state_dict(self, state_dict): |
|
|
self.dataset.load_state_dict(state_dict) |
|
|
|
|
|
def state_dict(self): |
|
|
return self.dataset.state_dict() |
|
|
|
|
|
def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]: |
|
|
if "control_image" in data: |
|
|
if "image" in data: |
|
|
data["control_image"] = FF.resize_to_nearest_bucket_image( |
|
|
data["control_image"], [data["image"].shape[-2:]], resize_mode="bicubic" |
|
|
) |
|
|
if "video" in data: |
|
|
batch_size, num_frames, num_channels, height, width = data["video"].shape |
|
|
data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video( |
|
|
data["control_video"], [[num_frames, height, width]], resize_mode="bicubic" |
|
|
) |
|
|
if _first_frame_only: |
|
|
msg = ( |
|
|
"The number of frames in the control video is less than the minimum bucket size " |
|
|
"specified. The first frame is being used as a single frame video. This " |
|
|
"message is logged at the first occurence and for every 128th occurence " |
|
|
"after that." |
|
|
) |
|
|
logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128) |
|
|
data["control_video"] = data["control_video"][0] |
|
|
return data |
|
|
|
|
|
if "control_video" in data: |
|
|
if "image" in data: |
|
|
data["control_image"] = FF.resize_to_nearest_bucket_image( |
|
|
data["control_video"][0], [data["image"].shape[-2:]], resize_mode="bicubic" |
|
|
) |
|
|
if "video" in data: |
|
|
batch_size, num_frames, num_channels, height, width = data["video"].shape |
|
|
data["control_video"], _first_frame_only = FF.resize_to_nearest_bucket_video( |
|
|
data["control_video"], [[num_frames, height, width]], resize_mode="bicubic" |
|
|
) |
|
|
if _first_frame_only: |
|
|
msg = ( |
|
|
"The number of frames in the control video is less than the minimum bucket size " |
|
|
"specified. The first frame is being used as a single frame video. This " |
|
|
"message is logged at the first occurence and for every 128th occurence " |
|
|
"after that." |
|
|
) |
|
|
logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE_CONTROL", msg, frequency=128) |
|
|
data["control_video"] = data["control_video"][0] |
|
|
return data |
|
|
|
|
|
if self.control_type == ControlType.CUSTOM: |
|
|
return data |
|
|
|
|
|
shallow_copy_data = dict(data.items()) |
|
|
is_image_control = "image" in shallow_copy_data |
|
|
is_video_control = "video" in shallow_copy_data |
|
|
if (is_image_control + is_video_control) != 1: |
|
|
raise ValueError("Exactly one of 'image' or 'video' should be present in the data.") |
|
|
for processor in self.control_processors: |
|
|
result = processor(**shallow_copy_data) |
|
|
result_keys = set(result.keys()) |
|
|
repeat_keys = result_keys.intersection(shallow_copy_data.keys()) |
|
|
if repeat_keys: |
|
|
logger.warning( |
|
|
f"Processor {processor.__class__.__name__} returned keys that already exist in " |
|
|
f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " |
|
|
f"be intended. Please rename the keys in the processor to avoid conflicts." |
|
|
) |
|
|
shallow_copy_data.update(result) |
|
|
if "control_output" in shallow_copy_data: |
|
|
|
|
|
control_output = shallow_copy_data.pop("control_output") |
|
|
|
|
|
control_output = FF.normalize(control_output, min=-1.0, max=1.0) |
|
|
key = "control_image" if is_image_control else "control_video" |
|
|
shallow_copy_data[key] = control_output |
|
|
return shallow_copy_data |
|
|
|
|
|
|
|
|
class ValidationControlDataset(torch.utils.data.IterableDataset):
    """Augments validation samples with control conditioning inputs.

    Unlike the training counterpart, tensor control outputs are converted back to
    PIL images/videos (via diffusers' ``VideoProcessor``) so they can be consumed
    directly by validation pipelines.
    """

    def __init__(
        self, dataset: torch.utils.data.IterableDataset, control_type: str, device: Optional[torch.device] = None
    ):
        super().__init__()

        self.dataset = dataset
        self.control_type = control_type
        self.device = device
        self._video_processor = VideoProcessor()

        # Build the processor chain for the requested control type. CUSTOM control
        # types carry their own conditioning, so no processors are registered.
        processors = []
        if control_type == ControlType.CANNY:
            processors.append(
                CannyProcessor(["control_output"], input_names={"image": "input", "video": "input"}, device=device)
            )
        elif control_type == ControlType.NONE:
            processors.append(
                CopyProcessor(["control_output"], input_names={"image": "input", "video": "input"})
            )
        self.control_processors = processors

        logger.info("Initialized ValidationControlDataset")

    def __iter__(self):
        logger.info("Starting ValidationControlDataset")
        for sample in self.dataset:
            yield self._run_control_processors(sample)

    def load_state_dict(self, state_dict):
        # State ownership is delegated to the wrapped dataset.
        self.dataset.load_state_dict(state_dict)

    def state_dict(self):
        return self.dataset.state_dict()

    def _run_control_processors(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Attach control conditioning to a validation sample.

        Samples with custom control types or precomputed control entries pass
        through untouched; otherwise the processor chain produces a
        "control_output", which is normalized and converted to PIL.
        """
        if self.control_type == ControlType.CUSTOM:
            return data
        if "control_image" in data or "control_video" in data:
            return data

        conditions = dict(data)
        has_image = "image" in conditions
        has_video = "video" in conditions
        if (has_image + has_video) != 1:
            raise ValueError("Exactly one of 'image' or 'video' should be present in the data.")

        for processor in self.control_processors:
            result = processor(**conditions)
            repeat_keys = set(result.keys()).intersection(conditions.keys())
            if repeat_keys:
                logger.warning(
                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
                    f"be intended. Please rename the keys in the processor to avoid conflicts."
                )
            conditions.update(result)

        if "control_output" in conditions:
            control_output = conditions.pop("control_output")
            if torch.is_tensor(control_output):
                control_output = FF.normalize(control_output, min=-1.0, max=1.0)
                ndim = control_output.ndim
                assert 3 <= ndim <= 5, "Control output should be at least ndim=3 and less than or equal to ndim=5"
                if ndim == 5:
                    # Batched video tensor -> list of PIL frames.
                    control_output = self._video_processor.postprocess_video(control_output, output_type="pil")
                else:
                    if ndim == 3:
                        control_output = control_output.unsqueeze(0)
                    control_output = self._video_processor.postprocess(control_output, output_type="pil")[0]
            conditions["control_image" if has_image else "control_video"] = control_output
        return conditions
|
|
|
|
|
|
|
|
|
|
|
def apply_frame_conditioning_on_latents(
    latents: torch.Tensor,
    expected_num_frames: int,
    channel_dim: int,
    frame_dim: int,
    frame_conditioning_type: FrameConditioningType,
    frame_conditioning_index: Optional[int] = None,
    concatenate_mask: bool = False,
) -> torch.Tensor:
    """Zero out non-conditioning frames in ``latents`` and trim/pad to ``expected_num_frames``.

    Args:
        latents: Latent tensor; ``frame_dim`` indexes frames, ``channel_dim`` channels.
        expected_num_frames: Target frame count after trimming or zero-padding.
        channel_dim: Dimension along which the keep-mask is concatenated when requested.
        frame_dim: Dimension indexing frames.
        frame_conditioning_type: Strategy selecting which frames are kept:
            ``INDEX`` (one fixed frame), ``PREFIX`` (random-length leading span),
            ``RANDOM`` (random non-empty frame subset), ``FIRST_AND_LAST``, or
            ``FULL`` (all frames; latents are left unmasked).
        frame_conditioning_index: Frame to keep for ``INDEX`` (required in that case;
            clamped to the last valid frame).
        concatenate_mask: If True, concatenate the binary keep-mask to the latents
            along ``channel_dim``.

    Returns:
        The masked (and optionally mask-concatenated) latents with exactly
        ``expected_num_frames`` frames along ``frame_dim``.

    Raises:
        ValueError: If ``INDEX`` conditioning is requested without ``frame_conditioning_index``.
    """
    num_frames = latents.size(frame_dim)
    mask = torch.zeros_like(latents)

    if frame_conditioning_type == FrameConditioningType.INDEX:
        if frame_conditioning_index is None:
            # Previously this fell into min(None, ...) and raised a cryptic TypeError.
            raise ValueError("`frame_conditioning_index` must be provided for INDEX frame conditioning.")
        frame_index = min(frame_conditioning_index, num_frames - 1)  # clamp to a valid frame
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = frame_index
        mask[tuple(indexing)] = 1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.PREFIX:
        # Keep a random-length prefix of at least one frame.
        frame_index = random.randint(1, num_frames)
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = slice(0, frame_index)
        mask[tuple(indexing)] = 1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.RANDOM:
        # Keep a random non-empty subset of frames.
        num_frames_to_keep = random.randint(1, num_frames)
        frame_indices = random.sample(range(num_frames), num_frames_to_keep)
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = frame_indices
        mask[tuple(indexing)] = 1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.FIRST_AND_LAST:
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = 0
        mask[tuple(indexing)] = 1
        indexing[frame_dim] = num_frames - 1
        mask[tuple(indexing)] = 1
        latents = latents * mask

    elif frame_conditioning_type == FrameConditioningType.FULL:
        # Every frame conditions; the mask is all ones and latents stay untouched.
        indexing = [slice(None)] * latents.ndim
        indexing[frame_dim] = slice(0, num_frames)
        mask[tuple(indexing)] = 1

    # NOTE(review): an unrecognized conditioning type silently falls through with an
    # all-zero mask — confirm whether it should raise instead.

    # Trim or zero-pad along the frame dimension to the expected length.
    if latents.size(frame_dim) >= expected_num_frames:
        slicing = [slice(None)] * latents.ndim
        slicing[frame_dim] = slice(expected_num_frames)
        latents = latents[tuple(slicing)]
        mask = mask[tuple(slicing)]
    else:
        pad_size = expected_num_frames - num_frames
        pad_shape = list(latents.shape)
        pad_shape[frame_dim] = pad_size
        padding = latents.new_zeros(pad_shape)
        latents = torch.cat([latents, padding], dim=frame_dim)
        mask = torch.cat([mask, padding], dim=frame_dim)

    if concatenate_mask:
        # NOTE(review): the original computed an unused channel-0 `slicing` here —
        # possibly intended to append a single mask channel. The full mask is
        # concatenated, preserving existing behavior.
        latents = torch.cat([latents, mask], dim=channel_dim)

    return latents
|
|
|