import os
from typing import List, Union

import cv2
import numpy as np
import PIL
import torch
import torch.utils.checkpoint
import torch.utils.data
from diffusers.pipeline_utils import DiffusionPipeline
from einops import rearrange
from tqdm.auto import tqdm

from video_diffusion.common.image_util import (
    annotate_image,
    make_grid,
    save_gif_mp4_folder_type,
)


class SampleLogger:
    def __init__(
        self,
        editing_prompts: List[str],
        clip_length: int,
        logdir: str,
        subdir: str = "sample",
        num_samples_per_prompt: int = 1,
        sample_seeds: List[int] = None,
        num_inference_steps: int = 20,
        guidance_scale: float = 7,
        strength: float = None,
        annotate: bool = False,
        annotate_size: int = 15,
        make_grid: bool = True,
        grid_column_size: int = 2,
        layout_mask_dir: str = None,  # root directory holding one subfolder of mask frames per layout name
        layouts_masks_orders: List[List[str]] = None,  # each entry is an ordering of layout names to merge
        stride: int = 1,
        n_sample_frame: int = 8,
        start_sample_frame: int = None,
        sampling_rate: int = 1,
        **args
    ) -> None:
        self.editing_prompts = editing_prompts
        self.clip_length = clip_length
        self.guidance_scale = guidance_scale
        self.num_inference_steps = num_inference_steps
        self.strength = strength
        if sample_seeds is None:
            max_num_samples_per_prompt = int(1e5)
            if num_samples_per_prompt > max_num_samples_per_prompt:
                raise ValueError(
                    f"num_samples_per_prompt must be at most {max_num_samples_per_prompt}"
                )
            sample_seeds = torch.randint(0, max_num_samples_per_prompt, (num_samples_per_prompt,))
            sample_seeds = sorted(sample_seeds.numpy().tolist())
        self.sample_seeds = sample_seeds
        self.logdir = os.path.join(logdir, subdir)
        os.makedirs(self.logdir, exist_ok=True)
        self.annotate = annotate
        self.annotate_size = annotate_size
        self.make_grid = make_grid
        self.grid_column_size = grid_column_size
        self.layout_mask_dir = layout_mask_dir
        self.layout_mask_orders = layouts_masks_orders
        self.stride = stride
        self.n_sample_frame = n_sample_frame
        self.start_sample_frame = start_sample_frame
        self.sampling_rate = sampling_rate

    def _read_mask(self, mask_path, index: int, dest_size=(64, 64)):
        # Load one mask frame, binarize it, and resize it to dest_size
        # (presumably the latent resolution) with nearest-neighbor
        # interpolation so the mask stays hard-edged. Returns (1, H, W).
        mask_path = os.path.join(mask_path, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = (mask > 0).astype(np.uint8)
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        mask = mask[np.newaxis, ...]
        return mask
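
    # A minimal sketch of the on-disk layout this expects (an assumption
    # inferred from the f"{index:05d}.png" pattern above, not documented
    # elsewhere):
    #   layout_masks/foreground/00000.png, 00001.png, ...
    # _read_mask("layout_masks/foreground", 0) would then return a uint8
    # array of shape (1, 64, 64) with values in {0, 1}.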

    def get_frame_indices(self, index):
        # Compute the frame indices for clip `index`; returns a one-shot generator.
        if self.start_sample_frame is not None:
            frame_start = self.start_sample_frame + self.stride * index
        else:
            frame_start = self.stride * index
        return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))
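
    # Worked example (values are illustrative): with start_sample_frame=0,
    # stride=2, sampling_rate=1, and n_sample_frame=8, get_frame_indices(3)
    # yields 6, 7, 8, ..., 13. The generator is exhausted after one pass,
    # which is presumably why read_layout_and_merge_masks recreates it for
    # each layout name.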

    def read_layout_and_merge_masks(self, index):
        layouts_all, masks_all = [], []
        for idx, layout_mask_order_per in enumerate(self.layout_mask_orders):
            layout_ = []
            for layout_name in layout_mask_order_per:  # loop over the layout names of this ordering
                frame_indices = self.get_frame_indices(index % self.clip_length)
                layout_mask_dir = os.path.join(self.layout_mask_dir, layout_name)
                mask = [self._read_mask(layout_mask_dir, i) for i in frame_indices]
                masks = np.stack(mask)  # (f, c, h, w)
                layout_.append(masks)
            layout_ = np.stack(layout_)  # (s, f, c, h, w), s = number of layout names
            # Merge the layouts per frame: a pixel is foreground if any layout covers it.
            merged_masks = []
            for i in range(int(self.n_sample_frame)):
                merged_mask_frame = np.sum(layout_[:, i, :, :], axis=0)
                merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
                merged_masks.append(merged_mask_frame)
            masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
            masks = torch.from_numpy(masks).half()
            layouts = rearrange(layout_, "s f c h w -> f s c h w")
            layouts = torch.from_numpy(layouts).half()
            layouts_all.append(layouts)
            masks_all.append(masks)  # merged per-frame mask tensor, (c, f, h, w)
        return masks_all, layouts_all
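
    # Shape sketch (assuming two layout names per ordering, 8 frames, and
    # 64x64 masks): each entry of layouts_all is a half tensor of shape
    # (8, 2, 1, 64, 64) = (f, s, c, h, w), and the matching entry of
    # masks_all is the per-frame union of those layouts with shape
    # (1, 8, 64, 64) = (c, f, h, w).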

    def log_sample_images(
        self, pipeline: DiffusionPipeline,
        device: torch.device, step: int,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        masks: Union[torch.FloatTensor, PIL.Image.Image] = None,
        layouts: Union[torch.FloatTensor, PIL.Image.Image] = None,
        latents: torch.FloatTensor = None,
        control: torch.FloatTensor = None,
        controlnet_conditioning_scale=None,
        negative_prompt: Union[str, List[str]] = None,
        blending_percentage=None,
        trajs=None,
        flatten_res=None,
        source_prompt=None,
        inject_step=None,
        old_qk=None,
        use_pnp=None,
        cluster_inversion_feature=None,
        vis_cross_attn=None,
        attn_inversion_dict=None,
    ):
        torch.cuda.empty_cache()
        samples_all = []
        attention_all = []
        # handle input image
        if image is not None:
            input_pil_images = pipeline.numpy_to_pil(tensor_to_numpy(image))[0]
            samples_all.append(input_pil_images)
            # samples_all.append([
            #     annotate_image(image, "input sequence", font_size=self.annotate_size) for image in input_pil_images
            # ])
        # masks_all, layouts_all = self.read_layout_and_merge_masks()
        # for idx, (prompt, masks, layouts) in enumerate(tqdm(zip(self.editing_prompts, masks_all, layouts_all), desc="Generating sample images")):
        for idx, prompt in enumerate(tqdm(self.editing_prompts, desc="Generating sample images")):
            for seed in self.sample_seeds:
                generator = torch.Generator(device=device)
                generator.manual_seed(seed)
                sequence_return = pipeline(
                    prompt=prompt,
                    image=image,  # torch.Size([8, 3, 512, 512])
                    latent_mask=masks,
                    layouts=layouts,
                    strength=self.strength,
                    generator=generator,
                    num_inference_steps=self.num_inference_steps,
                    clip_length=self.clip_length,
                    guidance_scale=self.guidance_scale,
                    num_images_per_prompt=1,
                    # used in null inversion
                    control=control,
                    controlnet_conditioning_scale=controlnet_conditioning_scale,
                    latents=latents,
                    # uncond_embeddings_list=uncond_embeddings_list,
                    blending_percentage=blending_percentage,
                    logdir=self.logdir,
                    trajs=trajs,
                    flatten_res=flatten_res,
                    negative_prompt=negative_prompt,
                    source_prompt=source_prompt,
                    inject_step=inject_step,
                    old_qk=old_qk,
                    use_pnp=use_pnp,
                    cluster_inversion_feature=cluster_inversion_feature,
                    vis_cross_attn=vis_cross_attn,
                    attn_inversion_dict=attn_inversion_dict,
                )
                sequence = sequence_return.images[0]
                torch.cuda.empty_cache()
                if self.annotate:
                    images = [
                        annotate_image(image, prompt, font_size=self.annotate_size) for image in sequence
                    ]
                else:
                    images = sequence
                if self.make_grid:
                    samples_all.append(images)
                save_path = os.path.join(self.logdir, f"step_{step}_{idx}_{seed}.gif")
                save_gif_mp4_folder_type(images, save_path)
        if self.make_grid:
            # Build one grid per frame across all collected sequences, then
            # save the sequence of grids as a single clip.
            samples_all = [make_grid(images, cols=int(np.ceil(np.sqrt(len(samples_all))))) for images in zip(*samples_all)]
            save_path = os.path.join(self.logdir, f"step_{step}.gif")
            save_gif_mp4_folder_type(samples_all, save_path)
        return samples_all


def tensor_to_numpy(image, b=1):
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().float().numpy()
    image = rearrange(image, "(b f) c h w -> b f h w c", b=b)
    return image
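

# ---------------------------------------------------------------------------
# Usage sketch. The prompt, directories, and layout names below are
# placeholders chosen for illustration, and running the actual editing step
# requires a pipeline that accepts the keyword arguments used above; only
# the pieces that run without a diffusion model are exercised here.
if __name__ == "__main__":
    # tensor_to_numpy maps [-1, 1] image tensors to [0, 1] float arrays in
    # (b, f, h, w, c) order, ready for pipeline.numpy_to_pil.
    fake_frames = torch.rand(8, 3, 64, 64) * 2 - 1   # (b*f, c, h, w) in [-1, 1]
    arr = tensor_to_numpy(fake_frames)               # (1, 8, 64, 64, 3)
    print(arr.shape, float(arr.min()), float(arr.max()))

    logger = SampleLogger(
        editing_prompts=["a placeholder editing prompt"],
        clip_length=8,
        logdir="./logs",                         # placeholder output directory
        layout_mask_dir="./layout_masks",        # expects <name>/00000.png, ...
        layouts_masks_orders=[["foreground"]],   # one ordering of layout names
    )
    # With a compatible pipeline and prepared inputs, sampling would then be:
    # logger.log_sample_images(pipeline, device=torch.device("cuda"), step=0,
    #                          image=..., masks=..., layouts=...)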