|
|
import gradio as gr
|
|
|
|
|
|
from gradio_toggle import Toggle
|
|
|
import argparse
|
|
|
import json
|
|
|
import os
|
|
|
import random
|
|
|
from datetime import datetime
|
|
|
from pathlib import Path
|
|
|
from diffusers.utils import logging
|
|
|
|
|
|
import imageio
|
|
|
import numpy as np
|
|
|
import safetensors.torch
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from PIL import Image
|
|
|
from transformers import T5EncoderModel, T5Tokenizer
|
|
|
import tempfile
|
|
|
from ltx_video.models.autoencoders.causal_video_autoencoder import (
|
|
|
CausalVideoAutoencoder,
|
|
|
)
|
|
|
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
|
|
|
from ltx_video.models.transformers.transformer3d import Transformer3DModel
|
|
|
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
|
|
|
from ltx_video.schedulers.rf import RectifiedFlowScheduler
|
|
|
from ltx_video.utils.conditioning_method import ConditioningMethod
|
|
|
from torchao.quantization import quantize_, int8_weight_only
|
|
|
|
|
|
# Suggested upper bounds for a generation request (pixels / frame count).
# In the visible main(), exceeding any of these only logs a warning; the
# request is neither rejected nor clamped.
MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257
|
|
|
|
|
|
|
|
|
def load_vae(vae_dir, int8=False):
    """Build the causal video VAE from ``vae_dir`` and load its weights on CPU.

    Args:
        vae_dir: Directory (``pathlib.Path`` or ``str``) that contains
            ``config.json`` and ``vae_diffusion_pytorch_model.safetensors``.
        int8: When true, apply torchao int8 weight-only quantization in place
            after the model has been moved to CPU.

    Returns:
        The ``CausalVideoAutoencoder`` with the checkpoint weights loaded.
    """
    # Accept plain strings too; the "/" joins below require a Path.
    vae_dir = Path(vae_dir)
    vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    vae_config_path = vae_dir / "config.json"

    with open(vae_config_path, "r", encoding="utf-8") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)

    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
    vae.load_state_dict(vae_state_dict)
    # load_state_dict (default assign=False) copies values into the module's
    # own parameters, so the loaded dict can be released before quantization
    # to lower peak memory.
    del vae_state_dict

    vae = vae.to("cpu")
    if int8:
        print("vae - quantization = true")
        quantize_(vae, int8_weight_only())
    return vae
|
|
|
|
|
|
|
|
|
def load_unet(unet_dir, int8=False):
    """Build the ``Transformer3DModel`` denoiser from ``unet_dir`` on CPU.

    Args:
        unet_dir: Directory (``pathlib.Path`` or ``str``) that contains
            ``config.json`` and ``unet_diffusion_pytorch_model.safetensors``.
        int8: When true, apply torchao int8 weight-only quantization in place
            after the model has been moved to CPU.

    Returns:
        The ``Transformer3DModel`` with the checkpoint weights loaded.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when checkpoint
            keys do not exactly match the model's parameters.
    """
    # Accept plain strings too; the "/" joins below require a Path.
    unet_dir = Path(unet_dir)
    unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    unet_config_path = unet_dir / "config.json"

    transformer_config = Transformer3DModel.load_config(unet_config_path)
    transformer = Transformer3DModel.from_config(transformer_config)

    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
    transformer.load_state_dict(unet_state_dict, strict=True)
    # Weights were copied into the model's parameters; drop the loaded dict
    # before quantization to lower peak memory.
    del unet_state_dict

    transformer = transformer.to("cpu")
    if int8:
        print("unet - quantization = true")
        quantize_(transformer, int8_weight_only())
    return transformer
|
|
|
|
|
|
|
|
|
def load_scheduler(scheduler_dir):
    """Instantiate the rectified-flow scheduler described in ``scheduler_dir``.

    Args:
        scheduler_dir: Directory (``pathlib.Path`` or ``str``) containing
            ``scheduler_config.json``.

    Returns:
        A ``RectifiedFlowScheduler`` built from that config.
    """
    # Coerce so callers may pass a plain string, matching load_vae/load_unet.
    scheduler_config_path = Path(scheduler_dir) / "scheduler_config.json"
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
    return RectifiedFlowScheduler.from_config(scheduler_config)
|
|
|
|
|
|
|
|
|
def load_image_to_tensor_with_resize_and_crop(image_path, target_height=512, target_width=768):
    """Load an image, center-crop it to the target aspect ratio, resize, normalize.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        target_height: Output height in pixels.
        target_width: Output width in pixels.

    Returns:
        A float tensor of shape ``(1, 3, 1, target_height, target_width)``
        with values in ``[-1, 1]``. The two singleton axes (0 and 2) are
        presumably the batch and frame axes expected by the video pipeline.
    """
    # The context manager closes the underlying file deterministically (Pillow
    # keeps multi-frame files open after load). convert() returns a new image,
    # so `image` stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    input_width, input_height = image.size

    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        # Source is wider than the target: keep the full height, trim the sides.
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        # Source is taller (or equal): keep the full width, trim top and bottom.
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height))

    # HWC uint8 -> CHW float, then map [0, 255] onto [-1, 1].
    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0
    return frame_tensor.unsqueeze(0).unsqueeze(2)
|
|
|
|
|
|
|
|
|
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Return symmetric padding that grows a source frame to the target size.

    Any odd leftover goes to the bottom/right side. The result is ordered
    ``(left, right, top, bottom)``, the layout ``torch.nn.functional.pad``
    uses for the last two dimensions. If the target is smaller than the
    source, the values are negative (floor division is used for the split).
    """

    def _split(total: int) -> tuple[int, int]:
        # Smaller half first; the remainder goes to the trailing side.
        leading = total // 2
        return leading, total - leading

    top, bottom = _split(target_height - source_height)
    left, right = _split(target_width - source_width)
    return (left, right, top, bottom)
|
|
|
|
|
|
|
|
|
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a free-text prompt into a short, hyphen-joined filename fragment.

    Only letters and whitespace are kept (lower-cased). Whole words are taken
    from the start of the prompt for as long as the joined result, including
    the ``-`` separators, stays within ``max_len`` characters. Stops at the
    first word that does not fit, so the result may be ``""``.

    Args:
        text: The prompt to convert.
        max_len: Maximum length of the returned string.

    Returns:
        The words joined with ``-``; ``len(result) <= max_len``.
    """
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )

    result = []
    current_length = 0
    for word in clean_text.split():
        # Every word after the first also costs one "-" separator; the old
        # count ignored separators and could overshoot max_len.
        added = len(word) + (1 if result else 0)
        if current_length + added > max_len:
            break
        result.append(word)
        current_length += added

    return "-".join(result)
|
|
|
|
|
|
|
|
|
|
|
|
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first unused path of the form
    ``dir/<base>_<prompt-slug>_<seed>_<r0>x<r1>x<r2>_<i><endswith><ext>``.

    ``i`` is tried from ``0`` up to ``index_range - 1``. The path is only
    checked, not created or reserved, so another writer could still claim it
    before the caller uses it.

    Args:
        base: Leading filename component.
        ext: Extension including the dot (appended verbatim).
        prompt: Prompt text, shortened via ``convert_prompt_to_filename``.
        seed: Seed embedded in the name.
        resolution: Three values formatted as ``AxBxC``.
        dir: Output directory (``pathlib.Path`` or ``str``). The name shadows
            the builtin but is kept for keyword-argument compatibility.
        endswith: Optional extra suffix placed before ``ext``.
        index_range: Number of indices to try.

    Raises:
        FileExistsError: if every candidate index is already taken.
    """
    # Coerce so a plain string directory works with the "/" join below.
    out_dir = Path(dir)
    base_filename = (
        f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}"
        f"_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    )
    suffix = f"{endswith or ''}{ext}"
    for i in range(index_range):
        filename = out_dir / f"{base_filename}_{i}{suffix}"
        if not os.path.exists(filename):
            return filename
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
|
|
|
|
|
|
|
|
|
def seed_everething(seed: int):
    """Seed the global RNGs of Python's ``random``, NumPy and PyTorch.

    The same ``seed`` is applied to all three, in that order. Returns None.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
|
|
|
|
|
|
|
def main(
|
|
|
img2vid_image="",
|
|
|
prompt="",
|
|
|
txt2vid_analytics_toggle=False,
|
|
|
negative_prompt="",
|
|
|
frame_rate=25,
|
|
|
seed=0,
|
|
|
num_inference_steps=30,
|
|
|
guidance_scale=3,
|
|
|
height=512,
|
|
|
width=768,
|
|
|
num_frames=121,
|
|
|
progress=gr.Progress(),
|
|
|
):
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
args = {
|
|
|
"ckpt_dir": "Lightricks/LTX-Video",
|
|
|
"num_inference_steps": num_inference_steps,
|
|
|
"guidance_scale": guidance_scale,
|
|
|
"height": height,
|
|
|
"width": width,
|
|
|
"num_frames": num_frames,
|
|
|
"frame_rate": frame_rate,
|
|
|
"prompt": prompt,
|
|
|
"negative_prompt": negative_prompt,
|
|
|
"seed": 0,
|
|
|
"output_path": os.path.join(tempfile.gettempdir(), "gradio"),
|
|
|
"num_images_per_prompt": 1,
|
|
|
"input_image_path": img2vid_image,
|
|
|
"input_video_path": "",
|
|
|
"bfloat16": True,
|
|
|
"disable_load_needed_only": False
|
|
|
}
|
|
|
logger.warning(f"Running generation with arguments: {args}")
|
|
|
|
|
|
seed_everething(args['seed'])
|
|
|
|
|
|
output_dir = (
|
|
|
Path(args['output_path'])
|
|
|
if args['output_path']
|
|
|
else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
|
|
|
)
|
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
if args['input_image_path']:
|
|
|
media_items_prepad = load_image_to_tensor_with_resize_and_crop(
|
|
|
args['input_image_path'], args['height'], args['width']
|
|
|
)
|
|
|
else:
|
|
|
media_items_prepad = None
|
|
|
|
|
|
height = args['height'] if args['height'] else media_items_prepad.shape[-2]
|
|
|
width = args['width'] if args['width'] else media_items_prepad.shape[-1]
|
|
|
num_frames = args['num_frames']
|
|
|
|
|
|
if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES:
|
|
|
logger.warning(
|
|
|
f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}."
|
|
|
)
|
|
|
|
|
|
|
|
|
height_padded = ((height - 1) // 32 + 1) * 32
|
|
|
width_padded = ((width - 1) // 32 + 1) * 32
|
|
|
num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
|
|
|
|
|
|
padding = calculate_padding(height, width, height_padded, width_padded)
|
|
|
|
|
|
logger.warning(
|
|
|
f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
|
|
|
)
|
|
|
|
|
|
if media_items_prepad is not None:
|
|
|
media_items = F.pad(
|
|
|
media_items_prepad, padding, mode="constant", value=-1
|
|
|
)
|
|
|
else:
|
|
|
media_items = None
|
|
|
|
|
|
|
|
|
vae = load_vae(Path(args['ckpt_dir']) / "vae", txt2vid_analytics_toggle)
|
|
|
unet = load_unet(Path(args['ckpt_dir']) / "unet", txt2vid_analytics_toggle)
|
|
|
scheduler = load_scheduler(Path(args['ckpt_dir']) / "scheduler")
|
|
|
patchifier = SymmetricPatchifier(patch_size=1)
|
|
|
text_encoder = T5EncoderModel.from_pretrained(
|
|
|
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
|
|
|
).to('cpu')
|
|
|
|
|
|
tokenizer = T5Tokenizer.from_pretrained(
|
|
|
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
|
|
|
)
|
|
|
|
|
|
|
|
|
submodel_dict = {
|
|
|
"transformer": unet,
|
|
|
"patchifier": patchifier,
|
|
|
"text_encoder": text_encoder,
|
|
|
"tokenizer": tokenizer,
|
|
|
"scheduler": scheduler,
|
|
|
"vae": vae,
|
|
|
}
|
|
|
|
|
|
pipeline = LTXVideoPipeline(**submodel_dict)
|
|
|
pipeline = pipeline.to('cpu')
|
|
|
|
|
|
|
|
|
sample = {
|
|
|
"prompt": args['prompt'],
|
|
|
"prompt_attention_mask": None,
|
|
|
"negative_prompt": args['negative_prompt'],
|
|
|
"negative_prompt_attention_mask": None,
|
|
|
"media_items": media_items,
|
|
|
}
|
|
|
|
|
|
generator = torch.Generator(device="cpu").manual_seed(args['seed'])
|
|
|
|
|
|
images = pipeline(
|
|
|
num_inference_steps=args['num_inference_steps'],
|
|
|
num_images_per_prompt=args['num_images_per_prompt'],
|
|
|
guidance_scale=args['guidance_scale'],
|
|
|
generator=generator,
|
|
|
output_type="pt",
|
|
|
callback_on_step_end=None,
|
|
|
height=height_padded,
|
|
|
width=width_padded,
|
|
|
num_frames=num_frames_padded,
|
|
|
frame_rate=args['frame_rate'],
|
|
|
**sample,
|
|
|
is_video=True,
|
|
|
vae_per_channel_normalize=True,
|
|
|
conditioning_method=(
|
|
|
ConditioningMethod.FIRST_FRAME
|
|
|
if media_items is not None
|
|
|
else ConditioningMethod.UNCONDITIONAL
|
|
|
),
|
|
|
mixed_precision=not args['bfloat16'],
|
|
|
load_needed_only=not args['disable_load_needed_only']
|
|
|
).images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|