# Hugging Face Space (status shown on the Spaces page: Paused)
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import tempfile | |
| import os | |
| import yaml | |
| import json | |
| from pathlib import Path | |
| import random | |
| # Importações de Hugging Face | |
| from huggingface_hub import snapshot_download, HfFolder | |
| from transformers import T5EncoderModel, T5TokenizerFast | |
| from diffusers import LTXLatentUpsamplePipeline, AutoModel | |
| from diffusers.models import AutoencoderKLLTXVideo, LTXVideoTransformer3DModel | |
| from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | |
| # Nossa pipeline customizada e utilitários | |
| from pipeline_ltx_condition_control import LTXConditionPipeline, LTXVideoCondition | |
| from diffusers.utils import export_to_video | |
| from PIL import Image, ImageOps | |
| import imageio | |
# --- Logging and warning configuration ---
import warnings

# Silence noisy library warnings. Each filterwarnings() call prepends its rule,
# so the catch-all message=".*" rule added last takes precedence and in practice
# hides warnings of every category, not only UserWarning/FutureWarning.
for _category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_category)
del _category
warnings.filterwarnings("ignore", message=".*")
# --- Direct model loading (no wrapper class) ---
print("=== [Inicialização da Aplicação] ===")

# 1. Load the generation configuration from the YAML file next to the app.
CONFIG_PATH = Path("ltxv-13b-0.9.8-dev-fp8.yaml")
if not CONFIG_PATH.exists():
    raise FileNotFoundError(f"Arquivo de configuração '{CONFIG_PATH}' não encontrado.")
# Explicit encoding: the default is locale-dependent.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    # safe_load() returns None for an empty file; normalise to a dict so the
    # many CONFIG.get(...) lookups elsewhere keep working.
    CONFIG = yaml.safe_load(f) or {}
if not isinstance(CONFIG, dict):
    raise ValueError(f"Configuração inválida em '{CONFIG_PATH}': esperado um mapeamento YAML.")
print(f"Configuração carregada de: {CONFIG_PATH}")
print(json.dumps(CONFIG, indent=2))
# Global parameters.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch_dtype = torch.bfloat16

# Hugging Face repositories and the single-file checkpoint to fetch.
LTX_REPO_ID = "Lightricks/LTX-Video"
base_repo = LTX_REPO_ID
checkpoint_path = "ltxv-13b-0.9.8-dev-fp8.safetensors"
upscaler_repo = "Lightricks/ltxv-spatial-upscaler-0.9.7"

# Frame rate as an int (FPS) and the same rate as a float (DEFAULT_FPS).
FPS = 24
DEFAULT_FPS = float(FPS)
FRAMES_ALIGNMENT = 8

# Download cache; None when HF_HOME is unset.
CACHE_DIR = os.environ.get("HF_HOME")

# Filesystem layout.
DEPS_DIR = Path("/data")
LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
BASE_CONFIG_PATH = LTX_VIDEO_REPO_DIR / "configs"
DEFAULT_CONFIG_FILE = BASE_CONFIG_PATH / "ltxv-13b-0.9.8-dev-fp8.yaml"
RESULTS_DIR = Path("/app/output")
# 2. Download the main checkpoint and assemble the generation pipeline.
import struct

from huggingface_hub import hf_hub_download
from transformers import T5Tokenizer

# NOTE(review): CausalVideoAutoencoder, Transformer3DModel, RectifiedFlowScheduler
# and SymmetricPatchifier are used below but are not imported anywhere in the
# visible file. They presumably come from the LTX-Video (`ltx_video`) package
# checked out under DEPS_DIR / "LTX-Video" — confirm and add those imports,
# otherwise this section raises NameError at startup.


def _read_safetensors_metadata(path):
    """Return the ``__metadata__`` mapping stored in a .safetensors file header.

    A .safetensors file begins with an 8-byte little-endian unsigned integer N,
    followed by an N-byte UTF-8 JSON header; optional string metadata lives under
    the ``"__metadata__"`` key. Only the header is read, never the tensor data.

    Args:
        path: Path to the .safetensors file.

    Returns:
        The metadata dict, or ``{}`` when the header has none.
    """
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        header = json.loads(fh.read(header_len).decode("utf-8"))
    return header.get("__metadata__") or {}


print(f"=== Baixando snapshot do repositório base: {base_repo} ===")
ckpt_path_str = hf_hub_download(repo_id=LTX_REPO_ID, filename=checkpoint_path, cache_dir=CACHE_DIR)
ckpt_path = Path(ckpt_path_str)
if not ckpt_path.is_file():
    raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

# 2.1 The checkpoint metadata carries a JSON string under "config"; read
#     `allowed_inference_steps` from it (None when absent).
metadata = _read_safetensors_metadata(ckpt_path)
config_str = metadata.get("config", "{}")
configs = json.loads(config_str)
allowed_inference_steps = configs.get("allowed_inference_steps")

# 2.2 Load each component on the CPU; the assembled pipeline is moved to
#     `device` afterwards. from_pretrained(ckpt_path) reads the weights straight
#     from the single .safetensors file.
print("Carregando VAE...")
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
print("Carregando Transformer...")
transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
print("Carregando Scheduler...")
scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
print("Carregando Text Encoder e Tokenizer...")
text_encoder_path = CONFIG["text_encoder_model_name_or_path"]
text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
patchifier = SymmetricPatchifier(patch_size=1)

# 2.3 Cast weights to bfloat16 only when the config asks for it; any other
#     precision value leaves the loaded dtypes untouched.
precision = CONFIG.get("precision", "bfloat16")
if precision == "bfloat16":
    vae.to(torch.bfloat16)
    transformer.to(torch.bfloat16)
    text_encoder.to(torch.bfloat16)

# 2.4 Assemble the pipeline. The prompt-enhancer models are optional and are
#     left unloaded to save memory.
print("Montando o objeto LTXVideoPipeline...")
submodel_dict = {
    "transformer": transformer,
    "patchifier": patchifier,
    "text_encoder": text_encoder,
    "tokenizer": tokenizer,
    "scheduler": scheduler,
    "vae": vae,
    "allowed_inference_steps": allowed_inference_steps,
    "prompt_enhancer_image_caption_model": None,
    "prompt_enhancer_image_caption_processor": None,
    "prompt_enhancer_llm_model": None,
    "prompt_enhancer_llm_tokenizer": None,
}
pipeline = LTXConditionPipeline(**submodel_dict)
pipeline.to(device)
pipeline.vae.enable_tiling()
# 3. Load the latent spatial-upscaler pipeline, reusing the VAE loaded above.
# NOTE(review): `vae` comes from the LTX checkpoint loader rather than from
# diffusers; confirm LTXLatentUpsamplePipeline accepts it. pipe_upsample also
# appears unused in this file (the upscale stage is disabled), so it may only
# cost GPU memory — consider loading it lazily.
print(f"Carregando o upsampler espacial de: {upscaler_repo}")
upsample_kwargs = {"vae": vae, "torch_dtype": torch_dtype}
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(upscaler_repo, **upsample_kwargs)
pipe_upsample.to(device)

print("=== [Inicialização Concluída] Aplicação pronta. ===")
# --- Main video-generation logic ---


def round_to_nearest_resolution_acceptable_by_vae(height, width, vae_temporal_compression_ratio):
    """Floor *height* and *width* to multiples of a VAE compression ratio.

    Despite the parameter name, the ratio is applied to both frame dimensions.
    Each value is rounded down, never up, to the closest multiple.

    Args:
        height: Frame height in pixels.
        width: Frame width in pixels.
        vae_temporal_compression_ratio: Multiple both dimensions must satisfy.

    Returns:
        A ``(height, width)`` tuple.
    """
    step = vae_temporal_compression_ratio
    return tuple(dim - dim % step for dim in (height, width))
import traceback


def _aligned_num_frames(duration, fps, temporal_compression):
    """Return a frame count of the form ``k * temporal_compression + 1`` for *duration* seconds.

    The raw count ``int(duration * fps) + 1`` is rounded down to that form.
    """
    raw_frames = int(duration * fps) + 1
    return ((raw_frames - 1) // temporal_compression) * temporal_compression + 1


def _build_conditions(condition_specs, width, height):
    """Wrap every provided conditioning image in an ``LTXVideoCondition``.

    Args:
        condition_specs: Iterable of ``(image, strength, frame_index)`` tuples.
            ``image`` is a PIL image, or ``None`` when the UI slot is empty, in
            which case the entry is skipped.
        width: Target width each image is cropped and resized to.
        height: Target height each image is cropped and resized to.

    Returns:
        A list of conditions, possibly empty.
    """
    conditions = []
    for image, strength, frame_index in condition_specs:
        if image is None:
            continue
        fitted = ImageOps.fit(image, (width, height), Image.LANCZOS)
        conditions.append(LTXVideoCondition(image=fitted, strength=strength, frame_index=int(frame_index)))
    return conditions


def _encode_mp4(frames, output_filename, fps, progress):
    """Encode float frames to an MP4 file, reporting per-frame progress.

    Args:
        frames: Sequence of float arrays, assumed to lie in [0, 1] (the pipeline
            is called with ``output_type="np"``).
        output_filename: Destination path; overwritten if it exists.
        fps: Frame rate of the encoded video.
        progress: Gradio progress callback.

    Returns:
        ``output_filename``.
    """
    # Clip before the uint8 cast: values slightly outside [0, 1] would otherwise
    # wrap around and appear as bright or dark speckles.
    frames_uint8 = [(np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8) for frame in frames]
    total = len(frames_uint8)
    with imageio.get_writer(output_filename, fps=fps, quality=8, macro_block_size=1) as writer:
        for index, frame in enumerate(frames_uint8, start=1):
            progress(index / total, desc="Codificando frames do vídeo...")
            writer.append_data(frame)
    return output_filename


def prepare_and_generate_video(
    condition_image_1, condition_strength_1, condition_frame_index_1,
    condition_image_2, condition_strength_2, condition_frame_index_2,
    prompt, duration, negative_prompt,
    height, width, guidance_scale, seed, randomize_seed,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video with the LTX pipeline and write it to ``output.mp4``.

    The video is rendered at 2/3 of the requested resolution, rounded down to
    sizes the VAE accepts. Up to two optional conditioning images are
    cropped/resized to that resolution and anchored at their frame indices.

    NOTE(review): ``guidance_scale`` (the UI slider) is not forwarded. The
    per-step schedule from ``CONFIG["guidance_scale"]`` is used instead. Wire it
    up if the slider is meant to have an effect.

    Returns:
        ``(path_to_mp4, seed_used)`` on success, or ``(None, seed)`` if any step
        fails. The error and its traceback are printed to stdout/stderr.
    """
    try:
        temporal_compression = pipeline.vae_temporal_compression_ratio
        # Width/height must be multiples of the VAE's *spatial* compression.
        # Fall back to the temporal ratio if the pipeline does not expose it.
        spatial_compression = getattr(pipeline, "vae_spatial_compression_ratio", temporal_compression)
        num_frames = _aligned_num_frames(duration, FPS, temporal_compression)

        downscale_factor = 2 / 3
        downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
            int(height * downscale_factor), int(width * downscale_factor), spatial_compression
        )

        conditions = _build_conditions(
            [
                (condition_image_1, condition_strength_1, condition_frame_index_1),
                (condition_image_2, condition_strength_2, condition_frame_index_2),
            ],
            downscaled_width,
            downscaled_height,
        )

        if randomize_seed:
            seed = random.randint(0, 2**32 - 1)
        seed = int(seed)  # torch.Generator.manual_seed requires an int

        call_kwargs = {
            "prompt": prompt,
            "height": downscaled_height,
            "width": downscaled_width,
            "skip_initial_inference_steps": 3,
            "skip_final_inference_steps": 0,
            "num_inference_steps": 30,
            "negative_prompt": negative_prompt,
            "guidance_scale": CONFIG.get("guidance_scale", [1, 1, 6, 8, 6, 1, 1]),
            "stg_scale": CONFIG.get("stg_scale", [0, 0, 4, 4, 4, 2, 1]),
            "rescaling_scale": CONFIG.get("rescaling_scale", [1, 1, 0.5, 0.5, 1, 1, 1]),
            "skip_block_list": CONFIG.get("skip_block_list", [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]),
            "frame_rate": int(DEFAULT_FPS),
            "generator": torch.Generator().manual_seed(seed),
            "output_type": "np",
            "media_items": None,
            "decode_timestep": CONFIG.get("decode_timestep", 0.05),
            "decode_noise_scale": CONFIG.get("decode_noise_scale", 0.025),
            "is_video": True,
            "vae_per_channel_normalize": True,
            "offload_to_cpu": False,
            "enhance_prompt": False,
            "num_frames": num_frames,
            "downscale_factor": CONFIG.get("downscale_factor", 0.6666666),
            "guidance_timesteps": CONFIG.get("guidance_timesteps", [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]),
            "sampler": CONFIG.get("sampler", "from_checkpoint"),
            "precision": CONFIG.get("precision", "float8_e4m3fn"),
            "stochastic_sampling": CONFIG.get("stochastic_sampling", False),
            "cfg_star_rescale": CONFIG.get("cfg_star_rescale", True),
        }
        # Added only after call_kwargs exists. Assigning into it earlier raised
        # UnboundLocalError whenever a conditioning image was supplied.
        if conditions:
            call_kwargs["conditions"] = conditions

        # Single-stage generation at the downscaled resolution.
        # TODO: the latent-upscale + high-resolution refine stages are disabled.
        video_frames = pipeline(**call_kwargs).frames[0]

        output_filename = _encode_mp4(video_frames, "output.mp4", FPS, progress)
        return output_filename, seed
    except Exception as e:
        # Best-effort UI boundary: report the failure without killing the app.
        print(f"Ocorreu um erro: {e}")
        traceback.print_exc()
        return None, seed
# --- Gradio user interface ---


def _condition_image_controls(accordion_label, image_label, open_by_default):
    """Build one collapsible conditioning-image group.

    Returns:
        The ``(image, strength, frame_index)`` input components of the group.
    """
    with gr.Accordion(accordion_label, open=open_by_default):
        image_input = gr.Image(label=image_label, type="pil")
        with gr.Row():
            strength_input = gr.Slider(label="Peso", minimum=0.0, maximum=1.0, step=0.05, value=1.0)
            frame_index_input = gr.Number(label="Frame", value=0, precision=0)
    return image_input, strength_input, frame_index_input


app_theme = gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])

with gr.Blocks(theme=app_theme, delete_cache=(60, 900)) as demo:
    gr.Markdown("# Geração de Vídeo com LTX\n**Crie vídeos a partir de texto e imagens de condição.**")
    with gr.Row():
        # Left column: prompt, conditioning images and generation settings.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Descreva o vídeo que você quer gerar...",
                lines=3,
                value="O Coringa dançando em um quarto escuro, iluminação dramática.",
            )
            condition_image_1, condition_strength_1, condition_frame_index_1 = _condition_image_controls(
                "Imagem de Condição 1", "Imagem 1", open_by_default=True
            )
            condition_image_2, condition_strength_2, condition_frame_index_2 = _condition_image_controls(
                "Imagem de Condição 2", "Imagem 2", open_by_default=False
            )
            duration = gr.Slider(label="Duração (s)", minimum=1.0, maximum=10.0, step=0.5, value=2)
            with gr.Accordion("Configurações Avançadas", open=False):
                negative_prompt = gr.Textbox(
                    label="Prompt Negativo",
                    lines=2,
                    value="pior qualidade, embaçado, tremido, distorcido",
                )
                with gr.Row():
                    height = gr.Slider(label="Altura", minimum=256, maximum=1536, step=32, value=768)
                    width = gr.Slider(label="Largura", minimum=256, maximum=1536, step=32, value=1152)
                with gr.Row():
                    guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=5.0, step=0.1, value=1.0)
                    randomize_seed = gr.Checkbox(label="Seed Aleatória", value=True)
                    seed = gr.Number(label="Seed", value=0, precision=0)
            generate_btn = gr.Button("Gerar Vídeo", variant="primary", size="lg")
        # Right column: results.
        with gr.Column(scale=1):
            output_video = gr.Video(label="Vídeo Gerado", height=400)
            generated_seed = gr.Number(label="Seed Utilizada", interactive=False)

    # The order must match prepare_and_generate_video's positional parameters.
    generation_inputs = [
        condition_image_1, condition_strength_1, condition_frame_index_1,
        condition_image_2, condition_strength_2, condition_frame_index_2,
        prompt, duration, negative_prompt,
        height, width, guidance_scale, seed, randomize_seed,
    ]
    generate_btn.click(
        fn=prepare_and_generate_video,
        inputs=generation_inputs,
        outputs=[output_video, generated_seed],
    )
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860, debug=True, show_error=True)