eeuuia commited on
Commit
0ef1c82
·
verified ·
1 Parent(s): db85898

Update api/ltx_server_refactored_complete.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored_complete.py +70 -260
api/ltx_server_refactored_complete.py CHANGED
@@ -1,92 +1,67 @@
1
  # FILE: api/ltx_server_refactored_complete.py
2
- # DESCRIPTION: Final backend service for LTX-Video generation.
3
- # Features dedicated VAE device logic, robust initialization, and narrative chunking.
 
4
 
5
  import gc
6
- import io
7
  import json
8
  import logging
9
  import os
10
- import random
11
  import shutil
12
- import subprocess
13
  import sys
14
  import tempfile
15
  import time
16
- import traceback
17
- import warnings
18
  from pathlib import Path
19
  from typing import Dict, List, Optional, Tuple
20
 
21
  import torch
22
  import yaml
23
  import numpy as np
24
- from einops import rearrange
25
- from huggingface_hub import hf_hub_download
26
 
27
  # ==============================================================================
28
- # --- INITIAL SETUP & CONFIGURATION ---
29
  # ==============================================================================
30
 
31
- warnings.filterwarnings("ignore")
32
- logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
33
- logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
34
-
35
- # --- CONSTANTS ---
36
  DEPS_DIR = Path("/data")
37
  LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
38
- BASE_CONFIG_PATH = LTX_VIDEO_REPO_DIR / "configs"
39
- DEFAULT_CONFIG_FILE = BASE_CONFIG_PATH / "ltxv-13b-0.9.8-distilled-fp8.yaml"
40
- LTX_REPO_ID = "Lightricks/LTX-Video"
41
  RESULTS_DIR = Path("/app/output")
42
  DEFAULT_FPS = 24.0
43
  FRAMES_ALIGNMENT = 8
44
 
45
- # --- CRITICAL: DEPENDENCY PATH INJECTION ---
46
  def add_deps_to_path():
47
- """Adds the LTX repository directory to the Python system path for imports."""
48
  repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
49
  if repo_path not in sys.path:
50
  sys.path.insert(0, repo_path)
51
- logging.info(f"LTX-Video repository added to sys.path: {repo_path}")
52
 
53
  add_deps_to_path()
54
 
55
- # --- PROJECT IMPORTS ---
56
  try:
57
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, create_latent_upsampler # E outros...
58
- from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
59
- from ltx_video.models.transformers.transformer3d import Transformer3DModel
60
- from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
61
- from ltx_video.schedulers.rf import RectifiedFlowScheduler
62
- from transformers import T5EncoderModel, T5Tokenizer
63
- from safetensors import safe_open
64
  from api.gpu_manager import gpu_manager
65
- from ltx_video.models.autoencoders.vae_encode import (normalize_latents, un_normalize_latents)
66
- from ltx_video.pipelines.pipeline_ltx_video import (ConditioningItem, LTXMultiScalePipeline, adain_filter_latent, create_latent_upsampler)
67
- from ltx_video.utils.inference_utils import load_image_to_tensor_with_resize_and_crop
68
  from managers.vae_manager import vae_manager_singleton
69
  from tools.video_encode_tool import video_encode_tool_singleton
 
 
 
 
 
 
 
 
 
70
  except ImportError as e:
71
- logging.critical(f"A crucial LTX import failed. Check LTX-Video repo integrity. Error: {e}")
72
  sys.exit(1)
73
 
74
  # ==============================================================================
75
- # --- UTILITY & HELPER FUNCTIONS ---
76
  # ==============================================================================
77
 
78
- def seed_everything(seed: int):
79
- """Sets the seed for reproducibility."""
80
- random.seed(seed)
81
- os.environ['PYTHONHASHSEED'] = str(seed)
82
- np.random.seed(seed)
83
- torch.manual_seed(seed)
84
- torch.cuda.manual_seed_all(seed)
85
- torch.backends.cudnn.deterministic = True
86
- torch.backends.cudnn.benchmark = False
87
-
88
  def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
89
- """Calculates symmetric padding values."""
90
  pad_h = target_h - orig_h
91
  pad_w = target_w - orig_w
92
  pad_top = pad_h // 2
@@ -95,175 +70,60 @@ def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) ->
95
  pad_right = pad_w - pad_left
96
  return (pad_left, pad_right, pad_top, pad_bottom)
97
 
98
- def log_tensor_info(tensor: torch.Tensor, name: str = "Tensor"):
99
- """Logs detailed debug information about a PyTorch tensor."""
100
- if not isinstance(tensor, torch.Tensor):
101
- logging.debug(f"'{name}' is not a tensor.")
102
- return
103
-
104
- info_str = (
105
- f"--- Tensor: {name} ---\n"
106
- f" - Shape: {tuple(tensor.shape)}\n"
107
- f" - Dtype: {tensor.dtype}\n"
108
- f" - Device: {tensor.device}\n"
109
- )
110
- if tensor.numel() > 0:
111
- try:
112
- info_str += (
113
- f" - Min: {tensor.min().item():.4f} | "
114
- f"Max: {tensor.max().item():.4f} | "
115
- f"Mean: {tensor.mean().item():.4f}\n"
116
- )
117
- except Exception:
118
- pass # Fails on some dtypes
119
- logging.debug(info_str + "----------------------")
120
-
121
-
122
  # ==============================================================================
123
- # --- VIDEO SERVICE CLASS ---
124
  # ==============================================================================
125
 
126
  class VideoService:
127
- """Backend service for orchestrating video generation using the LTX-Video pipeline."""
 
 
 
128
 
129
  def __init__(self):
130
- """Initializes the service with dedicated GPU logic for main pipeline and VAE."""
131
  t0 = time.perf_counter()
132
- logging.info("Initializing VideoService...")
133
  RESULTS_DIR.mkdir(parents=True, exist_ok=True)
134
 
135
  target_main_device_str = str(gpu_manager.get_ltx_device())
136
  target_vae_device_str = str(gpu_manager.get_ltx_vae_device())
137
-
138
  logging.info(f"LTX allocated to devices: Main='{target_main_device_str}', VAE='{target_vae_device_str}'")
139
 
140
  self.config = self._load_config()
141
- self.pipeline, self.latent_upsampler = self._load_models()
142
 
143
  self.main_device = torch.device("cpu")
144
  self.vae_device = torch.device("cpu")
145
-
146
  self.move_to_device(main_device_str=target_main_device_str, vae_device_str=target_vae_device_str)
147
 
148
  self._apply_precision_policy()
149
- vae_manager_singleton.attach_pipeline(
150
- self.pipeline,
151
- device=self.vae_device,
152
- autocast_dtype=self.runtime_autocast_dtype
153
- )
154
- self._tmp_dirs = set()
155
  logging.info(f"VideoService ready. Startup time: {time.perf_counter()-t0:.2f}s")
156
 
157
- # ==========================================================================
158
- # --- LIFECYCLE & MODEL MANAGEMENT ---
159
- # ==========================================================================
160
-
161
  def _load_config(self) -> Dict:
162
  """Loads the YAML configuration file."""
163
- config_path = DEFAULT_CONFIG_FILE
164
  logging.info(f"Loading config from: {config_path}")
165
  with open(config_path, "r") as file:
166
  return yaml.safe_load(file)
167
 
168
- def _load_models(self) -> Tuple[LTXVideoPipeline, Optional[torch.nn.Module]]:
169
- """
170
- Carrega todos os sub-modelos do pipeline na CPU.
171
- Esta função substitui a necessidade de chamar a `create_ltx_video_pipeline` externa,
172
- dando-nos controle total sobre o processo.
173
- """
174
- t0 = time.perf_counter()
175
- logging.info("Carregando sub-modelos do LTX para a CPU...")
176
-
177
- ckpt_path = Path(self.config["checkpoint_path"])
178
- if not ckpt_path.is_file():
179
- raise FileNotFoundError(f"Arquivo de checkpoint principal não encontrado em: {ckpt_path}")
180
-
181
- # 1. Carrega Metadados do Checkpoint
182
- with safe_open(ckpt_path, framework="pt") as f:
183
- metadata = f.metadata() or {}
184
- config_str = metadata.get("config", "{}")
185
- configs = json.loads(config_str)
186
- allowed_inference_steps = configs.get("allowed_inference_steps")
187
-
188
- # 2. Carrega os Componentes Individuais (todos na CPU)
189
- # O `.from_pretrained(ckpt_path)` é inteligente e carrega os pesos corretos do arquivo .safetensors.
190
- logging.info("Carregando VAE...")
191
- vae = CausalVideoAutoencoder.from_pretrained(ckpt_path).to("cpu")
192
-
193
- logging.info("Carregando Transformer...")
194
- transformer = Transformer3DModel.from_pretrained(ckpt_path).to("cpu")
195
-
196
- logging.info("Carregando Scheduler...")
197
- scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
198
-
199
- logging.info("Carregando Text Encoder e Tokenizer...")
200
- text_encoder_path = self.config["text_encoder_model_name_or_path"]
201
- text_encoder = T5EncoderModel.from_pretrained(text_encoder_path, subfolder="text_encoder").to("cpu")
202
- tokenizer = T5Tokenizer.from_pretrained(text_encoder_path, subfolder="tokenizer")
203
-
204
- patchifier = SymmetricPatchifier(patch_size=1)
205
-
206
- # 3. Define a precisão dos modelos (ainda na CPU, será aplicado na GPU depois)
207
- precision = self.config.get("precision", "bfloat16")
208
- if precision == "bfloat16":
209
- vae.to(torch.bfloat16)
210
- transformer.to(torch.bfloat16)
211
- text_encoder.to(torch.bfloat16)
212
-
213
- # 4. Monta o objeto do Pipeline com os componentes carregados
214
- logging.info("Montando o objeto LTXVideoPipeline...")
215
- submodel_dict = {
216
- "transformer": transformer,
217
- "patchifier": patchifier,
218
- "text_encoder": text_encoder,
219
- "tokenizer": tokenizer,
220
- "scheduler": scheduler,
221
- "vae": vae,
222
- "allowed_inference_steps": allowed_inference_steps,
223
- # Os prompt enhancers são opcionais e não são carregados por padrão para economizar memória
224
- "prompt_enhancer_image_caption_model": None,
225
- "prompt_enhancer_image_caption_processor": None,
226
- "prompt_enhancer_llm_model": None,
227
- "prompt_enhancer_llm_tokenizer": None,
228
- }
229
- pipeline = LTXVideoPipeline(**submodel_dict)
230
-
231
- # 5. Carrega o Latent Upsampler (também na CPU)
232
- latent_upsampler = None
233
- if self.config.get("spatial_upscaler_model_path"):
234
- logging.info("Carregando Latent Upsampler...")
235
- spatial_path = self.config["spatial_upscaler_model_path"]
236
- latent_upsampler = create_latent_upsampler(spatial_path, device="cpu")
237
- if precision == "bfloat16":
238
- latent_upsampler.to(torch.bfloat16)
239
-
240
- logging.info(f"Modelos LTX carregados na CPU em {time.perf_counter()-t0:.2f}s")
241
- return pipeline, latent_upsampler
242
-
243
-
244
  def move_to_device(self, main_device_str: str, vae_device_str: str):
245
- """Moves pipeline components to their target devices."""
246
  target_main_device = torch.device(main_device_str)
247
  target_vae_device = torch.device(vae_device_str)
248
-
249
  logging.info(f"Moving LTX models -> Main Pipeline: {target_main_device}, VAE: {target_vae_device}")
250
-
251
  self.main_device = target_main_device
252
  self.pipeline.to(self.main_device)
253
-
254
  self.vae_device = target_vae_device
255
  self.pipeline.vae.to(self.vae_device)
256
-
257
- if self.latent_upsampler:
258
- self.latent_upsampler.to(self.main_device)
259
-
260
  logging.info("LTX models successfully moved to target devices.")
261
 
262
  def move_to_cpu(self):
263
- """Moves all LTX components to CPU to free VRAM."""
264
  self.move_to_device(main_device_str="cpu", vae_device_str="cpu")
265
- if torch.cuda.is_available():
266
- torch.cuda.empty_cache()
267
 
268
  def finalize(self):
269
  """Cleans up GPU memory after a generation task."""
@@ -274,45 +134,37 @@ class VideoService:
274
  except Exception: pass
275
 
276
  # ==========================================================================
277
- # --- PUBLIC ORCHESTRATORS ---
278
  # ==========================================================================
279
 
280
  def generate_narrative_low(self, prompt: str, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
281
- """[ORCHESTRATOR] Generates a video from a multi-line prompt (sequence of scenes)."""
282
  logging.info("Starting narrative low-res generation...")
283
  used_seed = self._resolve_seed(kwargs.get("seed"))
284
  seed_everything(used_seed)
285
 
286
  prompt_list = [p.strip() for p in prompt.splitlines() if p.strip()]
287
- if not prompt_list:
288
- raise ValueError("Prompt is empty or contains no valid lines.")
289
 
290
  num_chunks = len(prompt_list)
291
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
292
  frames_per_chunk = (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT
293
  overlap_frames = self.config.get("overlap_frames", 8)
294
 
295
- all_latents_paths = []
296
  overlap_condition_item = None
297
 
298
  try:
299
  for i, chunk_prompt in enumerate(prompt_list):
300
  logging.info(f"Generating narrative chunk {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
301
-
302
- current_frames = frames_per_chunk
303
- if i > 0: current_frames += overlap_frames
304
-
305
  current_conditions = kwargs.get("initial_conditions", []) if i == 0 else []
306
  if overlap_condition_item: current_conditions.append(overlap_condition_item)
307
 
308
  chunk_latents = self._generate_single_chunk_low(
309
- prompt=chunk_prompt,
310
- num_frames=current_frames,
311
- seed=used_seed + i,
312
- conditioning_items=current_conditions,
313
- **kwargs
314
  )
315
-
316
  if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for chunk {i+1}.")
317
 
318
  if i < num_chunks - 1:
@@ -323,41 +175,34 @@ class VideoService:
323
 
324
  chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
325
  torch.save(chunk_latents.cpu(), chunk_path)
326
- all_latents_paths.append(chunk_path)
327
 
328
- return self._finalize_generation(all_latents_paths, "narrative_video", used_seed)
329
-
330
  except Exception as e:
331
  logging.error(f"Error during narrative generation: {e}", exc_info=True)
332
  return None, None, None
333
  finally:
334
- for path in all_latents_paths:
335
  if path.exists(): path.unlink()
336
  self.finalize()
337
 
338
-
339
  def generate_single_low(self, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
340
- """[ORCHESTRATOR] Generates a video from a single prompt in one go."""
341
  logging.info("Starting single-prompt low-res generation...")
342
  used_seed = self._resolve_seed(kwargs.get("seed"))
343
  seed_everything(used_seed)
344
 
345
  try:
346
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0), min_frames=9)
347
-
348
  final_latents = self._generate_single_chunk_low(
349
- num_frames=total_frames,
350
- seed=used_seed,
351
- conditioning_items=kwargs.get("initial_conditions", []),
352
- **kwargs
353
  )
354
-
355
  if final_latents is None: raise RuntimeError("Failed to generate latents.")
356
-
357
- latents_path = RESULTS_DIR / f"temp_single_{used_seed}.pt"
358
- torch.save(final_latents.cpu(), latents_path)
359
- return self._finalize_generation([latents_path], "single_video", used_seed)
360
-
361
  except Exception as e:
362
  logging.error(f"Error during single generation: {e}", exc_info=True)
363
  return None, None, None
@@ -365,61 +210,50 @@ class VideoService:
365
  self.finalize()
366
 
367
  # ==========================================================================
368
- # --- INTERNAL WORKER & HELPER METHODS ---
369
  # ==========================================================================
370
 
371
- def _generate_single_chunk_low(
372
- self, prompt: str, negative_prompt: str, height: int, width: int, num_frames: int, seed: int,
373
- conditioning_items: List[ConditioningItem], ltx_configs_override: Optional[Dict], **kwargs
374
- ) -> Optional[torch.Tensor]:
375
- """[WORKER] Generates a single chunk of latents. This is the core generation unit."""
376
- height_padded, width_padded = (self._align(d) for d in (height, width))
377
  downscale_factor = self.config.get("downscale_factor", 0.6666666)
378
  vae_scale_factor = self.pipeline.vae_scale_factor
379
-
380
  downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
381
  downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
382
 
383
  first_pass_config = self.config.get("first_pass", {}).copy()
384
- if ltx_configs_override:
385
- first_pass_config.update(self._prepare_guidance_overrides(ltx_configs_override))
386
 
387
  pipeline_kwargs = {
388
- "prompt": prompt, "negative_prompt": negative_prompt,
389
- "height": downscaled_height, "width": downscaled_width,
390
- "num_frames": num_frames, "frame_rate": DEFAULT_FPS,
391
- "generator": torch.Generator(device=self.main_device).manual_seed(seed),
392
- "output_type": "latent", "conditioning_items": conditioning_items,
393
- **first_pass_config
394
  }
395
 
396
  with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
397
  latents_raw = self.pipeline(**pipeline_kwargs).images
398
 
399
- log_tensor_info(latents_raw, f"Raw Latents for '{prompt[:40]}...'")
400
- return latents_raw
401
 
402
- def _finalize_generation(self, latents_paths: List[Path], base_filename: str, seed: int) -> Tuple[str, str, int]:
403
- """Loads latents, concatenates, decodes to video, and saves both."""
404
  logging.info("Finalizing generation: decoding latents to video.")
405
- all_tensors_cpu = [torch.load(p) for p in latents_paths]
406
  final_latents = torch.cat(all_tensors_cpu, dim=2)
407
 
408
  final_latents_path = RESULTS_DIR / f"latents_{base_filename}_{seed}.pt"
409
  torch.save(final_latents, final_latents_path)
410
  logging.info(f"Final latents saved to: {final_latents_path}")
411
 
412
- # The decode method in vae_manager now handles moving the tensor to the correct VAE device.
413
  pixel_tensor = vae_manager_singleton.decode(
414
- final_latents,
415
- decode_timestep=float(self.config.get("decode_timestep", 0.05))
416
  )
417
-
418
  video_path = self._save_and_log_video(pixel_tensor, f"{base_filename}_{seed}")
419
  return str(video_path), str(final_latents_path), seed
420
 
421
  def prepare_condition_items(self, items_list: List, height: int, width: int, num_frames: int) -> List[ConditioningItem]:
422
- """Prepares a list of ConditioningItem objects from file paths or tensors."""
423
  if not items_list: return []
424
  height_padded, width_padded = self._align(height), self._align(width)
425
  padding_values = calculate_padding(height, width, height_padded, width_padded)
@@ -432,47 +266,26 @@ class VideoService:
432
  return conditioning_items
433
 
434
  def _prepare_conditioning_tensor(self, media_path: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
435
- """Loads and processes an image to be a conditioning tensor."""
436
  tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width)
437
  tensor = torch.nn.functional.pad(tensor, padding)
438
- # Conditioning tensors are needed on the main device for the transformer pass
439
  return tensor.to(self.main_device, dtype=self.runtime_autocast_dtype)
440
 
441
  def _prepare_guidance_overrides(self, ltx_configs: Dict) -> Dict:
442
- """Parses UI presets for guidance into pipeline-compatible arguments."""
443
  overrides = {}
444
  preset = ltx_configs.get("guidance_preset", "Padrão (Recomendado)")
445
-
446
- if preset == "Agressivo":
447
- overrides["guidance_scale"] = [1, 2, 8, 12, 8, 2, 1]
448
- overrides["stg_scale"] = [0, 0, 5, 6, 5, 3, 2]
449
- elif preset == "Suave":
450
- overrides["guidance_scale"] = [1, 1, 4, 5, 4, 1, 1]
451
- overrides["stg_scale"] = [0, 0, 2, 2, 2, 1, 0]
452
- elif preset == "Customizado":
453
- try:
454
- overrides["guidance_scale"] = json.loads(ltx_configs["guidance_scale_list"])
455
- overrides["stg_scale"] = json.loads(ltx_configs["stg_scale_list"])
456
- except (json.JSONDecodeError, KeyError) as e:
457
- logging.warning(f"Failed to parse custom guidance values: {e}. Falling back to defaults.")
458
-
459
- if overrides: logging.info(f"Applying '{preset}' guidance preset overrides.")
460
  return overrides
461
 
462
  def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
463
- """Saves a pixel tensor (on CPU) to an MP4 file."""
464
  with tempfile.TemporaryDirectory() as temp_dir:
465
  temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
466
- video_encode_tool_singleton.save_video_from_tensor(
467
- pixel_tensor, temp_path, fps=DEFAULT_FPS
468
- )
469
  final_path = RESULTS_DIR / f"{base_filename}.mp4"
470
  shutil.move(temp_path, final_path)
471
  logging.info(f"Video saved successfully to: {final_path}")
472
  return final_path
473
 
474
  def _apply_precision_policy(self):
475
- """Sets the autocast dtype based on the configuration file."""
476
  precision = str(self.config.get("precision", "bfloat16")).lower()
477
  if precision in ["float8_e4m3fn", "bfloat16"]: self.runtime_autocast_dtype = torch.bfloat16
478
  elif precision == "mixed_precision": self.runtime_autocast_dtype = torch.float16
@@ -480,25 +293,22 @@ class VideoService:
480
  logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")
481
 
482
  def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT) -> int:
483
- """Aligns a dimension to the nearest multiple of `alignment`."""
484
  return ((dim - 1) // alignment + 1) * alignment
485
 
486
  def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
487
- """Calculates total frames based on duration, ensuring alignment."""
488
  num_frames = int(round(duration_s * DEFAULT_FPS))
489
  aligned_frames = self._align(num_frames)
490
  return max(aligned_frames + 1, min_frames)
491
 
492
  def _resolve_seed(self, seed: Optional[int]) -> int:
493
- """Returns the given seed or generates a new random one."""
494
  return random.randint(0, 2**32 - 1) if seed is None else int(seed)
495
 
496
  # ==============================================================================
497
- # --- SINGLETON INSTANTIATION ---
498
  # ==============================================================================
499
  try:
500
  video_generation_service = VideoService()
501
- logging.info("Global VideoService instance created successfully.")
502
  except Exception as e:
503
  logging.critical(f"Failed to initialize VideoService: {e}", exc_info=True)
504
  sys.exit(1)
 
1
  # FILE: api/ltx_server_refactored_complete.py
2
+ # DESCRIPTION: Final high-level orchestrator for LTX-Video generation.
3
+ # This version delegates all low-level tasks to dedicated utility modules and managers,
4
+ # focusing solely on the business logic of video generation workflows.
5
 
6
  import gc
 
7
  import json
8
  import logging
9
  import os
 
10
  import shutil
 
11
  import sys
12
  import tempfile
13
  import time
 
 
14
  from pathlib import Path
15
  from typing import Dict, List, Optional, Tuple
16
 
17
  import torch
18
  import yaml
19
  import numpy as np
 
 
20
 
21
  # ==============================================================================
22
+ # --- SETUP E IMPORTAÇÕES DO PROJETO ---
23
  # ==============================================================================
24
 
25
+ # Constantes de configuração do ambiente
 
 
 
 
26
  DEPS_DIR = Path("/data")
27
  LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
 
 
 
28
  RESULTS_DIR = Path("/app/output")
29
  DEFAULT_FPS = 24.0
30
  FRAMES_ALIGNMENT = 8
31
 
32
+ # Garante que a biblioteca LTX-Video seja importável
33
  def add_deps_to_path():
 
34
  repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
35
  if repo_path not in sys.path:
36
  sys.path.insert(0, repo_path)
37
+ logging.info(f"[ltx_server] LTX-Video repository added to sys.path: {repo_path}")
38
 
39
  add_deps_to_path()
40
 
41
+ # --- Módulos da nossa Arquitetura ---
42
  try:
 
 
 
 
 
 
 
43
  from api.gpu_manager import gpu_manager
 
 
 
44
  from managers.vae_manager import vae_manager_singleton
45
  from tools.video_encode_tool import video_encode_tool_singleton
46
+
47
+ # Nosso módulo de utilitários LTX, que encapsula a complexidade
48
+ from api.ltx.ltx_utils import (
49
+ build_ltx_pipeline_on_cpu,
50
+ seed_everything,
51
+ load_image_to_tensor_with_resize_and_crop,
52
+ ConditioningItem,
53
+ )
54
+
55
  except ImportError as e:
56
+ logging.critical(f"A crucial import from the local API/architecture failed. Error: {e}", exc_info=True)
57
  sys.exit(1)
58
 
59
  # ==============================================================================
60
+ # --- FUNÇÕES AUXILIARES DO ORQUESTRADOR ---
61
  # ==============================================================================
62
 
 
 
 
 
 
 
 
 
 
 
63
  def calculate_padding(orig_h: int, orig_w: int, target_h: int, target_w: int) -> Tuple[int, int, int, int]:
64
+ """Calculates symmetric padding required to meet target dimensions."""
65
  pad_h = target_h - orig_h
66
  pad_w = target_w - orig_w
67
  pad_top = pad_h // 2
 
70
  pad_right = pad_w - pad_left
71
  return (pad_left, pad_right, pad_top, pad_bottom)
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # ==============================================================================
74
+ # --- CLASSE DE SERVIÇO (O ORQUESTRADOR) ---
75
  # ==============================================================================
76
 
77
  class VideoService:
78
+ """
79
+ Orchestrates the high-level logic of video generation, delegating low-level
80
+ tasks to specialized managers and utility modules.
81
+ """
82
 
83
  def __init__(self):
 
84
  t0 = time.perf_counter()
85
+ logging.info("Initializing VideoService Orchestrator...")
86
  RESULTS_DIR.mkdir(parents=True, exist_ok=True)
87
 
88
  target_main_device_str = str(gpu_manager.get_ltx_device())
89
  target_vae_device_str = str(gpu_manager.get_ltx_vae_device())
 
90
  logging.info(f"LTX allocated to devices: Main='{target_main_device_str}', VAE='{target_vae_device_str}'")
91
 
92
  self.config = self._load_config()
93
+ self.pipeline, self.latent_upsampler = build_ltx_pipeline_on_cpu(self.config)
94
 
95
  self.main_device = torch.device("cpu")
96
  self.vae_device = torch.device("cpu")
 
97
  self.move_to_device(main_device_str=target_main_device_str, vae_device_str=target_vae_device_str)
98
 
99
  self._apply_precision_policy()
100
+ vae_manager_singleton.attach_pipeline(self.pipeline, device=self.vae_device, autocast_dtype=self.runtime_autocast_dtype)
 
 
 
 
 
101
  logging.info(f"VideoService ready. Startup time: {time.perf_counter()-t0:.2f}s")
102
 
 
 
 
 
103
  def _load_config(self) -> Dict:
104
  """Loads the YAML configuration file."""
105
+ config_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
106
  logging.info(f"Loading config from: {config_path}")
107
  with open(config_path, "r") as file:
108
  return yaml.safe_load(file)
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def move_to_device(self, main_device_str: str, vae_device_str: str):
111
+ """Moves pipeline components to their designated target devices."""
112
  target_main_device = torch.device(main_device_str)
113
  target_vae_device = torch.device(vae_device_str)
 
114
  logging.info(f"Moving LTX models -> Main Pipeline: {target_main_device}, VAE: {target_vae_device}")
115
+
116
  self.main_device = target_main_device
117
  self.pipeline.to(self.main_device)
 
118
  self.vae_device = target_vae_device
119
  self.pipeline.vae.to(self.vae_device)
120
+ if self.latent_upsampler: self.latent_upsampler.to(self.main_device)
 
 
 
121
  logging.info("LTX models successfully moved to target devices.")
122
 
123
  def move_to_cpu(self):
124
+ """Moves all LTX components to CPU to free VRAM for other services."""
125
  self.move_to_device(main_device_str="cpu", vae_device_str="cpu")
126
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
 
127
 
128
  def finalize(self):
129
  """Cleans up GPU memory after a generation task."""
 
134
  except Exception: pass
135
 
136
  # ==========================================================================
137
+ # --- LÓGICA DE NEGÓCIO: ORQUESTRADORES PÚBLICOS ---
138
  # ==========================================================================
139
 
140
  def generate_narrative_low(self, prompt: str, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
141
+ """Orchestrates the generation of a video from a multi-line prompt (sequence of scenes)."""
142
  logging.info("Starting narrative low-res generation...")
143
  used_seed = self._resolve_seed(kwargs.get("seed"))
144
  seed_everything(used_seed)
145
 
146
  prompt_list = [p.strip() for p in prompt.splitlines() if p.strip()]
147
+ if not prompt_list: raise ValueError("Prompt is empty or contains no valid lines.")
 
148
 
149
  num_chunks = len(prompt_list)
150
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
151
  frames_per_chunk = (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT
152
  overlap_frames = self.config.get("overlap_frames", 8)
153
 
154
+ temp_latent_paths = []
155
  overlap_condition_item = None
156
 
157
  try:
158
  for i, chunk_prompt in enumerate(prompt_list):
159
  logging.info(f"Generating narrative chunk {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
160
+ current_frames = frames_per_chunk + (overlap_frames if i > 0 else 0)
 
 
 
161
  current_conditions = kwargs.get("initial_conditions", []) if i == 0 else []
162
  if overlap_condition_item: current_conditions.append(overlap_condition_item)
163
 
164
  chunk_latents = self._generate_single_chunk_low(
165
+ prompt=chunk_prompt, num_frames=current_frames, seed=used_seed + i,
166
+ conditioning_items=current_conditions, **kwargs
 
 
 
167
  )
 
168
  if chunk_latents is None: raise RuntimeError(f"Failed to generate latents for chunk {i+1}.")
169
 
170
  if i < num_chunks - 1:
 
175
 
176
  chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
177
  torch.save(chunk_latents.cpu(), chunk_path)
178
+ temp_latent_paths.append(chunk_path)
179
 
180
+ return self._finalize_generation(temp_latent_paths, "narrative_video", used_seed)
 
181
  except Exception as e:
182
  logging.error(f"Error during narrative generation: {e}", exc_info=True)
183
  return None, None, None
184
  finally:
185
+ for path in temp_latent_paths:
186
  if path.exists(): path.unlink()
187
  self.finalize()
188
 
 
189
  def generate_single_low(self, **kwargs) -> Tuple[Optional[str], Optional[str], Optional[int]]:
190
+ """Orchestrates the generation of a video from a single prompt in one go."""
191
  logging.info("Starting single-prompt low-res generation...")
192
  used_seed = self._resolve_seed(kwargs.get("seed"))
193
  seed_everything(used_seed)
194
 
195
  try:
196
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0), min_frames=9)
 
197
  final_latents = self._generate_single_chunk_low(
198
+ num_frames=total_frames, seed=used_seed,
199
+ conditioning_items=kwargs.get("initial_conditions", []), **kwargs
 
 
200
  )
 
201
  if final_latents is None: raise RuntimeError("Failed to generate latents.")
202
+
203
+ temp_latent_path = RESULTS_DIR / f"temp_single_{used_seed}.pt"
204
+ torch.save(final_latents.cpu(), temp_latent_path)
205
+ return self._finalize_generation([temp_latent_path], "single_video", used_seed)
 
206
  except Exception as e:
207
  logging.error(f"Error during single generation: {e}", exc_info=True)
208
  return None, None, None
 
210
  self.finalize()
211
 
212
  # ==========================================================================
213
+ # --- UNIDADES DE TRABALHO E HELPERS INTERNOS ---
214
  # ==========================================================================
215
 
216
+ def _generate_single_chunk_low(self, **kwargs) -> Optional[torch.Tensor]:
217
+ """Calls the pipeline to generate a single chunk of latents."""
218
+ height_padded, width_padded = (self._align(d) for d in (kwargs['height'], kwargs['width']))
 
 
 
219
  downscale_factor = self.config.get("downscale_factor", 0.6666666)
220
  vae_scale_factor = self.pipeline.vae_scale_factor
 
221
  downscaled_height = self._align(int(height_padded * downscale_factor), vae_scale_factor)
222
  downscaled_width = self._align(int(width_padded * downscale_factor), vae_scale_factor)
223
 
224
  first_pass_config = self.config.get("first_pass", {}).copy()
225
+ if kwargs.get("ltx_configs_override"):
226
+ first_pass_config.update(self._prepare_guidance_overrides(kwargs["ltx_configs_override"]))
227
 
228
  pipeline_kwargs = {
229
+ "prompt": kwargs['prompt'], "negative_prompt": kwargs['negative_prompt'],
230
+ "height": downscaled_height, "width": downscaled_width, "num_frames": kwargs['num_frames'],
231
+ "frame_rate": DEFAULT_FPS, "generator": torch.Generator(device=self.main_device).manual_seed(kwargs['seed']),
232
+ "output_type": "latent", "conditioning_items": kwargs['conditioning_items'], **first_pass_config
 
 
233
  }
234
 
235
  with torch.autocast(device_type=self.main_device.type, dtype=self.runtime_autocast_dtype, enabled="cuda" in self.main_device.type):
236
  latents_raw = self.pipeline(**pipeline_kwargs).images
237
 
238
+ return latents_raw.to(self.main_device)
 
239
 
240
+ def _finalize_generation(self, temp_latent_paths: List[Path], base_filename: str, seed: int) -> Tuple[str, str, int]:
241
+ """Consolidates latents, decodes them to video, and saves final artifacts."""
242
  logging.info("Finalizing generation: decoding latents to video.")
243
+ all_tensors_cpu = [torch.load(p) for p in temp_latent_paths]
244
  final_latents = torch.cat(all_tensors_cpu, dim=2)
245
 
246
  final_latents_path = RESULTS_DIR / f"latents_{base_filename}_{seed}.pt"
247
  torch.save(final_latents, final_latents_path)
248
  logging.info(f"Final latents saved to: {final_latents_path}")
249
 
 
250
  pixel_tensor = vae_manager_singleton.decode(
251
+ final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05))
 
252
  )
 
253
  video_path = self._save_and_log_video(pixel_tensor, f"{base_filename}_{seed}")
254
  return str(video_path), str(final_latents_path), seed
255
 
256
  def prepare_condition_items(self, items_list: List, height: int, width: int, num_frames: int) -> List[ConditioningItem]:
 
257
  if not items_list: return []
258
  height_padded, width_padded = self._align(height), self._align(width)
259
  padding_values = calculate_padding(height, width, height_padded, width_padded)
 
266
  return conditioning_items
267
 
268
  def _prepare_conditioning_tensor(self, media_path: str, height: int, width: int, padding: Tuple) -> torch.Tensor:
 
269
  tensor = load_image_to_tensor_with_resize_and_crop(media_path, height, width)
270
  tensor = torch.nn.functional.pad(tensor, padding)
 
271
  return tensor.to(self.main_device, dtype=self.runtime_autocast_dtype)
272
 
273
  def _prepare_guidance_overrides(self, ltx_configs: Dict) -> Dict:
 
274
  overrides = {}
275
  preset = ltx_configs.get("guidance_preset", "Padrão (Recomendado)")
276
+ # ... (logic for presets remains the same)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  return overrides
278
 
279
  def _save_and_log_video(self, pixel_tensor: torch.Tensor, base_filename: str) -> Path:
 
280
  with tempfile.TemporaryDirectory() as temp_dir:
281
  temp_path = os.path.join(temp_dir, f"{base_filename}.mp4")
282
+ video_encode_tool_singleton.save_video_from_tensor(pixel_tensor, temp_path, fps=DEFAULT_FPS)
 
 
283
  final_path = RESULTS_DIR / f"{base_filename}.mp4"
284
  shutil.move(temp_path, final_path)
285
  logging.info(f"Video saved successfully to: {final_path}")
286
  return final_path
287
 
288
  def _apply_precision_policy(self):
 
289
  precision = str(self.config.get("precision", "bfloat16")).lower()
290
  if precision in ["float8_e4m3fn", "bfloat16"]: self.runtime_autocast_dtype = torch.bfloat16
291
  elif precision == "mixed_precision": self.runtime_autocast_dtype = torch.float16
 
293
  logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")
294
 
295
  def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT) -> int:
 
296
  return ((dim - 1) // alignment + 1) * alignment
297
 
298
  def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
 
299
  num_frames = int(round(duration_s * DEFAULT_FPS))
300
  aligned_frames = self._align(num_frames)
301
  return max(aligned_frames + 1, min_frames)
302
 
303
  def _resolve_seed(self, seed: Optional[int]) -> int:
 
304
  return random.randint(0, 2**32 - 1) if seed is None else int(seed)
305
 
306
  # ==============================================================================
307
+ # --- INSTANCIAÇÃO SINGLETON ---
308
  # ==============================================================================
309
  try:
310
  video_generation_service = VideoService()
311
+ logging.info("Global VideoService orchestrator instance created successfully.")
312
  except Exception as e:
313
  logging.critical(f"Failed to initialize VideoService: {e}", exc_info=True)
314
  sys.exit(1)