import os
import random
import tempfile
from pathlib import Path

import gradio as gr
import imageio
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image

from inference import (
    create_ltx_video_pipeline,
    create_latent_upsampler,
    load_image_to_tensor_with_resize_and_crop,
    seed_everething,  # sic -- this is the name exported by inference.py
    calculate_padding,
    load_media_file,
)
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

# Load the distilled-model config; the YAML stores repo-relative weight paths,
# which are replaced below with the local paths returned by hf_hub_download.
CONFIG_YAML = "configs/ltxv-13b-0.9.7-distilled.yaml"
with open(CONFIG_YAML, "r") as f:
    CFG = yaml.safe_load(f)

HF_REPO = "LTTEAM/VideoAI"
MODELS_DIR = "downloaded_models"
Path(MODELS_DIR).mkdir(exist_ok=True)

print("Downloading models (if not already cached)…")
ckpt = hf_hub_download(repo_id=HF_REPO, filename=CFG["checkpoint_path"], local_dir=MODELS_DIR)
CFG["checkpoint_path"] = ckpt
upscaler_path = hf_hub_download(repo_id=HF_REPO, filename=CFG["spatial_upscaler_model_path"], local_dir=MODELS_DIR)
CFG["spatial_upscaler_model_path"] = upscaler_path

# Build the pipeline and the latent upsampler on CPU; generate() moves them
# to the requested device per call.
print("Initializing pipeline on CPU…")
pipeline = create_ltx_video_pipeline(
    ckpt_path=CFG["checkpoint_path"],
    precision=CFG["precision"],
    text_encoder_model_name_or_path=CFG["text_encoder_model_name_or_path"],
    sampler=CFG["sampler"],
    device="cpu",
    enhance_prompt=False,
    prompt_enhancer_image_caption_model_name_or_path=CFG["prompt_enhancer_image_caption_model_name_or_path"],
    prompt_enhancer_llm_model_name_or_path=CFG["prompt_enhancer_llm_model_name_or_path"],
)
print("Pipeline ready.")
print("Initializing latent upsampler on CPU…")
upsampler = create_latent_upsampler(CFG["spatial_upscaler_model_path"], device="cpu")
print("Upsampler ready.")

FPS = 30.0
MAX_FRAMES = 257  # 32 * 8 + 1; frame counts must be of the form 8k + 1
MIN_DIM = 256
FIXED_SIDE = 768
MAX_RES = CFG.get("max_resolution", 1280)


def calc_new_dims(w, h):
    """Snap (w, h) so the shorter side becomes FIXED_SIDE, rounded to multiples
    of 32 and clamped to [MIN_DIM, MAX_RES]. Returns (height, width).
    Currently not wired into the UI."""
    if w == 0 or h == 0:
        return FIXED_SIDE, FIXED_SIDE
    if w >= h:
        nh = FIXED_SIDE
        nw = round((nh * w / h) / 32) * 32
    else:
        nw = FIXED_SIDE
        nh = round((nw * h / w) / 32) * 32
    return (
        int(max(MIN_DIM, min(nh, MAX_RES))),
        int(max(MIN_DIM, min(nw, MAX_RES))),
    )
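
# Worked example (with the defaults above): calc_new_dims(1920, 1080) returns
# (768, 1280) -- the shorter side snaps to FIXED_SIDE, and the longer side,
# 1376 after rounding to a multiple of 32, is clamped to MAX_RES = 1280.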


def get_duration(*args, duration_ui=0, **kwargs):
    # GPU-time budget in seconds per request; presumably meant to feed a
    # ZeroGPU @spaces.GPU(duration=...) decorator. Not referenced elsewhere.
    return 75 if duration_ui > 7 else 60


def generate(prompt, neg_prompt, img_path, vid_path,
             height, width, mode, duration_ui, frames_to_use,
             seed, rand_seed, cfg_scale, improve_tex, device_choice,
             progress=gr.Progress(track_tqdm=True)):
    # Move both models to the requested device for this call.
    dev = "cuda" if device_choice == "GPU" and torch.cuda.is_available() else "cpu"
    print(f"Using device: {dev}")
    pipeline.to(dev)
    upsampler.to(dev)

    if rand_seed:
        seed = random.randint(0, 2**32 - 1)
    seed_everething(int(seed))

    # Snap the requested duration to a valid frame count (8k + 1, capped at MAX_FRAMES).
    tf = max(1, round(duration_ui * FPS))
    n8 = round((tf - 1) / 8)
    n_frames = max(9, min(n8 * 8 + 1, MAX_FRAMES))

    # Pad the output size up to multiples of 32, as the model requires.
    h, w = int(height), int(width)
    h_pad = ((h - 1) // 32 + 1) * 32
    w_pad = ((w - 1) // 32 + 1) * 32
    pad = calculate_padding(h, w, h_pad, w_pad)
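
    # Worked example: duration_ui=2.0 at 30 fps gives tf=60, n8=round(59/8)=7,
    # so n_frames = 7*8 + 1 = 57.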

    kwargs = {
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "height": h_pad,
        "width": w_pad,
        "num_frames": n_frames,
        "frame_rate": int(FPS),
        "generator": torch.Generator(device=dev).manual_seed(int(seed)),
        "output_type": "pt",
        "decode_timestep": CFG["decode_timestep"],
        "decode_noise_scale": CFG["decode_noise_scale"],
        "stochastic_sampling": CFG["stochastic_sampling"],
        "is_video": True,
        "vae_per_channel_normalize": True,
        "mixed_precision": CFG["precision"] == "mixed_precision",
        "offload_to_cpu": False,
        "enhance_prompt": False,
    }

    # Resolve the spatiotemporal guidance (STG) mode from the config; both the
    # short and long spellings map to the same strategy.
    mode_stg = CFG.get("stg_mode", "attention_values").lower()
    stg_map = {
        "stg_av": SkipLayerStrategy.AttentionValues,
        "attention_values": SkipLayerStrategy.AttentionValues,
        "stg_as": SkipLayerStrategy.AttentionSkip,
        "attention_skip": SkipLayerStrategy.AttentionSkip,
        "stg_r": SkipLayerStrategy.Residual,
        "residual": SkipLayerStrategy.Residual,
        "stg_t": SkipLayerStrategy.TransformerBlock,
        "transformer_block": SkipLayerStrategy.TransformerBlock,
    }
    kwargs["skip_layer_strategy"] = stg_map.get(mode_stg, SkipLayerStrategy.AttentionValues)

    # Conditioning: a first-frame image (image-to-video) or source frames (video-to-video).
    if mode == "image-to-video" and img_path:
        t = load_image_to_tensor_with_resize_and_crop(img_path, h, w)
        t = torch.nn.functional.pad(t, pad)
        # Condition frame 0 on the image at full strength.
        kwargs["conditioning_items"] = [ConditioningItem(t.to(dev), 0, 1.0)]
    elif mode == "video-to-video" and vid_path:
        mi = load_media_file(vid_path, h, w, int(frames_to_use), pad).to(dev)
        kwargs["media_items"] = mi

    if improve_tex:
        # Two-pass multi-scale generation: a low-resolution first pass, latent
        # upsampling, then a refining second pass.
        pipe_ms = LTXMultiScalePipeline(pipeline, upsampler)
        fp = CFG.get("first_pass", {}).copy()
        fp["guidance_scale"] = float(cfg_scale)
        fp.pop("num_inference_steps", None)
        sp = CFG.get("second_pass", {}).copy()
        sp["guidance_scale"] = float(cfg_scale)
        sp.pop("num_inference_steps", None)
        kwargs.update({
            "downscale_factor": CFG["downscale_factor"],
            "first_pass": fp,
            "second_pass": sp,
        })
        images = pipe_ms(**kwargs).images
    else:
        # Single-pass generation, reusing the first-pass settings directly.
        fp0 = CFG.get("first_pass", {})
        kwargs.update({
            "timesteps": fp0.get("timesteps"),
            "guidance_scale": float(cfg_scale),
            "stg_scale": fp0.get("stg_scale"),
            "rescaling_scale": fp0.get("rescaling_scale"),
            "skip_block_list": fp0.get("skip_block_list"),
        })
        for k in ["first_pass", "second_pass", "downscale_factor", "num_inference_steps"]:
            kwargs.pop(k, None)
        images = pipeline(**kwargs).images

    # Crop the padding back off; images is (batch, channels, frames, height, width).
    l, r, t_, b = pad
    sh = None if b == 0 else -b
    sw = None if r == 0 else -r
    vid_t = images[0, :, :n_frames, t_:sh, l:sw]
    arr = vid_t.permute(1, 2, 3, 0).cpu().float().numpy()  # -> (F, H, W, C)
    arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
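
    # With the default 512x704 output (already multiples of 32) the padding is
    # (0, 0, 0, 0), so the crop above is a no-op.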

    # Encode the frames to an MP4 in a fresh temporary directory.
    out_dir = tempfile.mkdtemp()
    out_path = os.path.join(out_dir, f"output_{random.randint(0, 99999)}.mp4")
    with imageio.get_writer(out_path, fps=int(FPS), macro_block_size=1) as writer:
        for i in range(arr.shape[0]):
            progress(i / arr.shape[0], desc="Saving video")
            writer.append_data(arr[i])

    return out_path, seed


css = """
#col-container { margin:0 auto; max-width:900px; }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## LTX Video 0.9.7 Distilled App")
    gr.Markdown(
        "[Model on HF](https://huggingface.co/LTTEAM/VideoAI) · "
        "[GitHub](https://github.com/Lightricks/LTX-Video)"
    )

    with gr.Row():
        with gr.Column():
            device = gr.Radio(["CPU", "GPU"], label="Run on device", value="CPU")

            with gr.Tab("Image→Video"):
                img_in = gr.Image(label="Input image", type="filepath", sources=["upload", "clipboard", "webcam"])
                prompt1 = gr.Textbox(label="Prompt", lines=2, value="The creature moves")
                btn1 = gr.Button("Generate from image")

            with gr.Tab("Text→Video"):
                prompt2 = gr.Textbox(label="Prompt", lines=2, value="A dragon flying over a castle")
                btn2 = gr.Button("Generate from text")

            with gr.Tab("Video→Video"):
                vid_in = gr.Video(label="Input video", sources=["upload", "webcam"])
                frames = gr.Slider(label="Frames to use", minimum=9, maximum=MAX_FRAMES, step=8, value=9)
                prompt3 = gr.Textbox(label="Prompt", lines=2, value="Convert to anime style")
                btn3 = gr.Button("Generate from video")

            duration = gr.Slider(label="Duration (seconds)", minimum=0.3, maximum=8.5, step=0.1, value=2)
            improve = gr.Checkbox(label="Improve detail", value=True)

        with gr.Column():
            out_vid = gr.Video(label="Result", interactive=False)

    # Fixed generation settings, passed into generate() as hidden state.
    mode_state = gr.State("image-to-video")
    seed_state = gr.State(42)
    neg_state = gr.State("worst quality, inconsistent motion, blurry, jittery, distorted")
    cfg_state = gr.State(CFG["first_pass"]["guidance_scale"])
    h_state = gr.State(512)
    w_state = gr.State(704)

    btn1.click(fn=generate,
               inputs=[prompt1, neg_state, img_in, gr.State(""), h_state, w_state,
                       mode_state, duration, frames, seed_state, gr.State(True),
                       cfg_state, improve, device],
               outputs=[out_vid, seed_state])
    btn2.click(fn=generate,
               inputs=[prompt2, neg_state, gr.State(""), gr.State(""), h_state, w_state,
                       gr.State("text-to-video"), duration, frames, seed_state, gr.State(True),
                       cfg_state, improve, device],
               outputs=[out_vid, seed_state])
    # btn3 must pass "video-to-video": reusing mode_state ("image-to-video")
    # here would silently skip the video-conditioning branch in generate().
    btn3.click(fn=generate,
               inputs=[prompt3, neg_state, gr.State(""), vid_in, h_state, w_state,
                       gr.State("video-to-video"), duration, frames, seed_state, gr.State(True),
                       cfg_state, improve, device],
               outputs=[out_vid, seed_state])


if __name__ == "__main__":
    demo.queue().launch(share=True)