import argparse
import gc
import json
import os
from pathlib import PosixPath

import numpy as np
import torch
import torchvision.transforms as TT
from einops import rearrange
from PIL import Image
from torchvision.transforms import InterpolationMode, ToPILImage
from torchvision.transforms.functional import resize

from diffusers import CogVideoXDPMScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import export_to_video

from lineart_extractor.annotator.lineart import LineartDetector
from models.cogvideox_transformer_3d_control import Control3DModel, Controled_CogVideoXTransformer3DModel
from models.global_local_memory_module import global_local_memory
from models.pipeline_cogvideox_image2video import Controled_Memory_CogVideoXImageToVideoPipeline
from utils.autoreg_video_save_function import autoreg_video_save
from utils.utils import load_model_from_config, load_segmented_safe_weights
from videoxl.constants import IMAGE_TOKEN_INDEX, TOKEN_PERFRAME
from videoxl.mm_utils import tokenizer_image_token
from videoxl.model.builder import load_pretrained_model

try:
    import decord
except ImportError:
    raise ImportError(
        "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
    )
decord.bridge.set_bridge("torch")
from decord import VideoReader, cpu
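

# Resize so the target rectangle is fully covered, then crop (center or random)
# to exactly (height, width).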
def _resize_for_rectangle_crop(arr, height, width, video_reshape_mode):
    image_size = height, width
    reshape_mode = video_reshape_mode
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )
    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)
    delta_h = h - image_size[0]
    delta_w = w - image_size[1]
    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr


def get_frame_length(frame_path):
    video_reader = decord.VideoReader(uri=frame_path.as_posix())
    video_num_frames = len(video_reader)
    return video_num_frames
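

# Read frames [frames_start, frames_end) with decord, drop trailing frames so the
# count is 4k + 1 (the temporal length the CogVideoX VAE expects), then resize and
# center-crop. Relies on the module-level `args` for the target height/width.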
def process_frame(frame_path, frames_start, frames_end):
    video_reader = decord.VideoReader(uri=frame_path.as_posix())
    start_frame = frames_start
    end_frame = frames_end
    indices = list(range(start_frame, end_frame))
    frames = video_reader.get_batch(indices)
    selected_num_frames = frames.shape[0]
    print("selected_num_frames", selected_num_frames)
    # Keep the first (4k + 1) frames, since that is the temporal length the VAE requires.
    remainder = (3 + (selected_num_frames % 4)) % 4
    if remainder != 0:
        frames = frames[:-remainder]
    selected_num_frames = frames.shape[0]
    assert (selected_num_frames - 1) % 4 == 0
    # Training transforms
    frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
    frames = _resize_for_rectangle_crop(frames, height=args.height, width=args.width, video_reshape_mode="center")
    final_frames = frames.contiguous()
    return final_frames
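

# Single-image variant: add a frame axis and apply the same resize/center-crop.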
def process_image(frames):
    # Training transforms
    frames = frames.unsqueeze(0).permute(0, 3, 1, 2)  # [F, C, H, W]
    frames = _resize_for_rectangle_crop(frames, height=args.height, width=args.width, video_reshape_mode="center")
    final_frames = frames.contiguous()
    return final_frames
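

# Encode a [F, C, H, W] clip into the VAE latent distribution (not sampled here);
# kept as a helper, not called in the flow below.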
def encode_sketch(video, pipe):
    video = video.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
    video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    latent_dist = pipe.vae.encode(video).latent_dist
    return latent_dist
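

# Turn an RGB clip into line-art conditioning latents: run the lineart detector,
# binarize at 0.78, invert, normalize to [-1, 1], then VAE-encode both the full
# clip and its first frame, scaled by the VAE scaling factor.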
def process_sketch(sketch, linear_detector, pipe):
    sketch = sketch.to("cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        sketch = linear_detector(sketch, coarse=False)
    sketch = (sketch > 0.78).float()
    sketch = 1 - sketch
    sketch = sketch.repeat(1, 3, 1, 1)
    sketch = (sketch - 0.5) / 0.5
    sketch = sketch.contiguous()
    sketch = sketch.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
    sketch = sketch.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    image = sketch[:, :, :1].clone()
    with torch.no_grad():
        sketch = pipe.vae.encode(sketch).latent_dist
        sketches_first_frame = pipe.vae.encode(image).latent_dist
        sketch = sketch.sample() * pipe.vae.config.scaling_factor
        sketches_first_frame = sketches_first_frame.sample() * pipe.vae.config.scaling_factor
    sketch = sketch.permute(0, 2, 1, 3, 4)
    sketch = sketch.to(memory_format=torch.contiguous_format)
    sketches_first_frame = sketches_first_frame.permute(0, 2, 1, 3, 4)
    sketches_first_frame = sketches_first_frame.to(memory_format=torch.contiguous_format)
    return sketch, sketches_first_frame
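

# Same line-art pipeline for a single reference image (PIL -> tensor -> latents).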
def process_sketch_image(sketch, linear_detector, pipe):
    sketch = torch.tensor(np.array(sketch))
    sketch = process_image(sketch)
    sketch = sketch.to("cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        sketch = linear_detector(sketch, coarse=False)
    sketch = (sketch > 0.78).float()
    sketch = 1 - sketch
    sketch = sketch.repeat(1, 3, 1, 1)
    sketch = (sketch - 0.5) / 0.5
    sketch = sketch.contiguous()
    sketch = sketch.to(pipe.vae.device, dtype=pipe.vae.dtype).unsqueeze(0)
    sketch = sketch.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    with torch.no_grad():
        sketch = pipe.vae.encode(sketch).latent_dist
        sketch = sketch.sample() * pipe.vae.config.scaling_factor
    sketch = sketch.permute(0, 2, 1, 3, 4)
    sketch = sketch.to(memory_format=torch.contiguous_format)
    return sketch
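

# Run the pipeline for one segment, decode the frames to PIL, export an mp4, and
# stitch all segments generated so far into one autoregressive video.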
def log_validation(
    pipe,
    args,
    pipeline_args,
    device,
    use_glm=False,
    global_memory=None,
    local_memory=None,
    glm=None,
    past_latents=None,
):
    scheduler_args = {}
    idx = pipeline_args.pop("segment", None)
    video_key = pipeline_args.pop("video_key", None)
    clip_memory = idx != 0
    print("clip_memory", clip_memory)
    if "variance_type" in pipe.scheduler.config:
        variance_type = pipe.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
    pipe = pipe.to(device)
    generator = torch.Generator(device=device).manual_seed(args.seed) if args.seed else None
    videos = []
    video_tensor_path = os.path.join(args.output_dir, video_key)
    os.makedirs(video_tensor_path, exist_ok=True)
    print(video_tensor_path, "video_tensor_path")
    with torch.no_grad():
        for _ in range(args.num_validation_videos):
            frames_output, past_latents = pipe(
                **pipeline_args,
                generator=generator,
                output_type="pt",
                num_inference_steps=50,
                use_glm=use_glm,
                global_memory=global_memory,
                local_memory=local_memory,
                glm=glm,
                video_tensor_path=video_tensor_path,
                past_latents=past_latents[:, -4:-2] if (past_latents is not None) else None,
                clip_memory=clip_memory,
            )
            pt_images = frames_output.frames[0]
            # TODO: decide whether the first frame should be kept here.
            pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
            image_np = VaeImageProcessor.pt_to_numpy(pt_images)
            image_pil = VaeImageProcessor.numpy_to_pil(image_np)
            videos.append(image_pil)
    phase_name = f"inference_{idx}"
    video_filenames = []
    for video in videos:
        final_output_dir = os.path.join(args.output_dir, video_key)
        os.makedirs(final_output_dir, exist_ok=True)
        filename = os.path.join(final_output_dir, f"{phase_name}_video.mp4")
        export_to_video(video, filename, fps=args.fps)
        video_filenames.append(filename)
    autoreg_video_save(base_path=final_output_dir, suffix="inference_{}_video.mp4", num_videos=idx + 1)
    # Frame 65 is presumably where the next segment starts (max_num_frames 81 - overlap 16).
    return videos[0][65]
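

# Split `total_frames` into windows of `segment_length` that overlap by `overlap`
# frames, e.g. save_segments(200, 81, 16) -> [(0, 81), (65, 146)]; a trailing
# remainder shorter than `segment_length` is dropped.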
def save_segments(total_frames, segment_length, overlap):
    start_frame = 0
    segments = []
    while start_frame + segment_length <= total_frames:
        end_frame = start_frame + segment_length
        segments.append((start_frame, end_frame))
        start_frame = end_frame - overlap
    return segments
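

# Load the controlled CogVideoX pipeline, the control branch, the line-art detector,
# the Video-XL memory model, and the global-local memory (GLM) module, then generate
# each test video segment by segment.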
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    load_dtype = torch.bfloat16
    transformer = Controled_CogVideoXTransformer3DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        torch_dtype=load_dtype,
    )
    control_config_path = "model_json/control_model_15_small.json"
    transformer_control_config = load_model_from_config(control_config_path)
    transformer_control = Control3DModel(**transformer_control_config)
    control_weight_files = [args.control_weight]
    transformer_control = load_segmented_safe_weights(transformer_control, control_weight_files)
    transformer_control = transformer_control.to(load_dtype)
    # Line-art detector used to derive sketch conditioning from frames.
    linear_detector = LineartDetector("cuda", dtype=torch.bfloat16)
    gen_kwargs = {"do_sample": True, "temperature": 1, "top_p": None, "num_beams": 1, "use_cache": True, "max_new_tokens": 2}
    # Prefer flash attention when available; fall back to SDPA otherwise.
    try:
        video_tokenizer, video_model, clip_image_processor, _ = load_pretrained_model(
            args.llm_model_path, None, "llava_qwen", device_map="cuda", attn_implementation="flash_attention_2"
        )
    except Exception:
        video_tokenizer, video_model, clip_image_processor, _ = load_pretrained_model(
            args.llm_model_path, None, "llava_qwen", device_map="cuda", attn_implementation="sdpa"
        )
    video_model.config.beacon_ratio = [8]  # delete this line to enable random compression with ratios in {2, 4, 8}
    vllm_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nCan you describe the scene and color in anime?<|im_end|>\n<|im_start|>assistant\n"
    input_ids = tokenizer_image_token(vllm_prompt, video_tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(video_model.device)
    video_model.to(dtype=torch.bfloat16)
    glm = global_local_memory()
    glm_weight_files = [args.glm_weight]
    glm = load_segmented_safe_weights(glm, glm_weight_files)
    glm = glm.to(load_dtype)
    glm = glm.to("cuda")
    print("successfully loaded glm")
    pipe = Controled_Memory_CogVideoXImageToVideoPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=torch.bfloat16,
        transformer=transformer,
        transformer_control=transformer_control,
    ).to("cuda")
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
    del transformer, transformer_control
    gc.collect()
    torch.cuda.empty_cache()
    # pipe.enable_sequential_cpu_offload()  # optional: trade speed for lower VRAM usage
    if args.enable_slicing:
        pipe.vae.enable_slicing()
    if args.enable_tiling:
        pipe.vae.enable_tiling()
    with open("test_json/long_testset.json", "r") as json_file:
        video_info = json.load(json_file)
    for video_key, value in video_info.items():
        print("------------")
        print(video_key)
        # Append the fixed quality suffix once per video (not once per segment).
        validation_prompt = value["prompt"] + " High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
        video_path = PosixPath(value["video_path"])
        reference_image_path = str(value["reference_image"])
        use_glm = False
        global_image = None
        frame_path = video_path
        video_num_frames = get_frame_length(frame_path)
        segments = save_segments(total_frames=video_num_frames, segment_length=args.max_num_frames, overlap=16)
        print(segments)
        past_latents = None
        for seg_idx, segment in enumerate(segments):
            print(seg_idx)
            print(segment)
            videos = process_frame(frame_path, frames_start=segment[0], frames_end=segment[1])
            sketches, sketches_first_frame = process_sketch(videos, linear_detector, pipe)
            torch.cuda.empty_cache()
            print("sketches!!!", sketches.shape)
            to_pil = ToPILImage()
            if global_image is None:
                print("------------------")
                print(reference_image_path)
                print("------------------")
                if reference_image_path != "0":
                    image = Image.open(reference_image_path).convert("RGB")
                    global_image = image
                    sketches_first_frame = process_sketch_image(global_image, linear_detector, pipe)
                else:
                    image = to_pil(videos[0]).convert("RGB")
                    global_image = image
                    sketches_first_frame = process_sketch_image(global_image, linear_detector, pipe)
            else:
                image = global_image
            pipeline_args = {
                "image": image,
                "prompt": validation_prompt,
                "guidance_scale": args.guidance_scale,
                "use_dynamic_cfg": args.use_dynamic_cfg,
                "height": args.height,
                "width": args.width,
                "sketches": sketches,
                "sketches_first_frame": sketches_first_frame,
                "num_frames": args.max_num_frames,
                "segment": seg_idx,
                "video_key": video_key,
            }
            if use_glm:
                # From the second segment on, build global/local memory from the
                # stitched video written so far by autoreg_video_save.
                auto_path = os.path.join(args.output_dir, video_key, "autoreg_video_1.mp4")
                vr = VideoReader(auto_path, ctx=cpu(0))
                total_frame_num = len(vr)
                max_frame = min(total_frame_num, 650)
                uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frame, dtype=int)
                frame_idx = uniform_sampled_frames.tolist()
                frames = vr.get_batch(frame_idx).numpy()
                print(frames.shape)
                global_videos = clip_image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(video_model.device, dtype=torch.bfloat16)
                local_videos = global_videos[-20:]
                beacon_skip_first = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[1].item()
                with torch.inference_mode():
                    # Global memory: run Video-XL over all sampled frames and keep the
                    # key/value caches of three late layers.
                    num_tokens = TOKEN_PERFRAME * global_videos.shape[0]
                    beacon_skip_last = beacon_skip_first + num_tokens
                    video_model.generate(input_ids, images=[global_videos], modalities=["video"], beacon_skip_first=beacon_skip_first, beacon_skip_last=beacon_skip_last, **gen_kwargs)
                    indices = [-9, -5, -1]
                    global_memory = torch.cat([
                        torch.cat([rearrange(video_model.past_key_values[i][0], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                        torch.cat([rearrange(video_model.past_key_values[i][1], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                    ], dim=0).unsqueeze(0)
                    video_model.clear_past_key_values()
                    video_model.memory.reset()
                    print(global_memory.shape)
                    torch.cuda.empty_cache()
                    # Local memory: same extraction over only the last 20 frames.
                    num_tokens = TOKEN_PERFRAME * local_videos.shape[0]
                    beacon_skip_last = beacon_skip_first + num_tokens
                    video_model.generate(input_ids, images=[local_videos], modalities=["video"], beacon_skip_first=beacon_skip_first, beacon_skip_last=beacon_skip_last, **gen_kwargs)
                    local_memory = torch.cat([
                        torch.cat([rearrange(video_model.past_key_values[i][0], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                        torch.cat([rearrange(video_model.past_key_values[i][1], "b c h w -> b h (c w)") for i in indices], dim=0).unsqueeze(0),
                    ], dim=0).unsqueeze(0)
                    video_model.clear_past_key_values()
                    video_model.memory.reset()
                del global_videos, local_videos
                torch.cuda.empty_cache()
                print(local_memory.shape)
            else:
                global_memory = None
                local_memory = None
            last_image = log_validation(
                pipe=pipe,
                args=args,
                pipeline_args=pipeline_args,
                device="cuda",
                use_glm=use_glm,
                global_memory=global_memory,
                local_memory=local_memory,
                glm=glm,
                past_latents=past_latents,
            )
            torch.cuda.empty_cache()
            # All segments after the first condition on the GLM memory.
            use_glm = True


def get_args():
    parser = argparse.ArgumentParser(description="Inference script for long-video generation with CogVideoX.")
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6,
        help="The guidance scale to use while sampling validation videos.",
    )
    # Model information
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--llm_model_path",
        type=str,
        default=None,
        required=True,
        help="Path to the pretrained Video-XL (llava_qwen) model used for memory extraction.",
    )
    parser.add_argument(
        "--control_weight",
        type=str,
        default=None,
        required=True,
        help="Path to the control-branch (Control3DModel) weights.",
    )
    parser.add_argument(
        "--glm_weight",
        type=str,
        default=None,
        required=True,
        help="Path to the global-local memory (GLM) module weights.",
    )
    parser.add_argument(
        "--use_dynamic_cfg",
        action="store_true",
        default=False,
        help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--num_validation_videos",
        type=int,
        default=1,
        help="Number of videos that should be generated during validation per `validation_prompt`.",
    )
    # Sampling information
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible sampling.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="cogvideox-i2v-lora",
        help="The output directory where the generated videos will be written.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="All input videos are resized to this height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=720,
        help="All input videos are resized to this width.",
    )
    parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
    parser.add_argument(
        "--max_num_frames", type=int, default=81, help="All input videos will be truncated to this many frames."
    )
    parser.add_argument(
        "--enable_slicing",
        action="store_true",
        default=False,
        help="Whether or not to use VAE slicing for saving memory.",
    )
    parser.add_argument(
        "--enable_tiling",
        action="store_true",
        default=False,
        help="Whether or not to use VAE tiling for saving memory.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up inference. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)