eeuuia commited on
Commit
c1fb4ff
·
verified ·
1 Parent(s): 02e3a82

Update api/ltx_server_refactored_complete.py

Browse files
Files changed (1) hide show
  1. api/ltx_server_refactored_complete.py +31 -35
api/ltx_server_refactored_complete.py CHANGED
@@ -1,7 +1,7 @@
1
  # FILE: api/ltx_server_refactored_complete.py
2
- # DESCRIPTION: Final high-level orchestrator for LTX-Video generation.
3
- # This version features a unified generation workflow, random seed generation,
4
- # delegation to specialized modules, and advanced debugging capabilities.
5
 
6
  import gc
7
  import json
@@ -36,7 +36,7 @@ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
36
  RESULTS_DIR = Path("/app/output")
37
  DEFAULT_FPS = 24.0
38
  FRAMES_ALIGNMENT = 8
39
- LTX_REPO_ID = "Lightricks/LTX-Video" # Repositório de onde os modelos são baixados
40
 
41
  # Garante que a biblioteca LTX-Video seja importável
42
  def add_deps_to_path():
@@ -45,7 +45,7 @@ def add_deps_to_path():
45
  sys.path.insert(0, repo_path)
46
  logging.info(f"[ltx_server] LTX-Video repository added to sys.path: {repo_path}")
47
 
48
- #add_deps_to_path()
49
 
50
  # --- Módulos da nossa Arquitetura ---
51
  try:
@@ -99,7 +99,7 @@ class VideoService:
99
  logging.info(f"LTX allocated to devices: Main='{target_main_device_str}', VAE='{target_vae_device_str}'")
100
 
101
  self.config = self._load_config()
102
- self._resolve_model_paths_from_cache() # Etapa crítica para encontrar os modelos
103
 
104
  self.pipeline, self.latent_upsampler = build_ltx_pipeline_on_cpu(self.config)
105
 
@@ -119,31 +119,16 @@ class VideoService:
119
  return yaml.safe_load(file)
120
 
121
  def _resolve_model_paths_from_cache(self):
122
- """
123
- Uses hf_hub_download to find the absolute paths to model files in the cache,
124
- updating the in-memory config. This makes the app resilient to cache structure.
125
- """
126
  logging.info("Resolving model paths from Hugging Face cache...")
127
  cache_dir = os.environ.get("HF_HOME")
128
  try:
129
- # Resolve o caminho do checkpoint principal
130
- main_ckpt_filename = self.config["checkpoint_path"]
131
- main_ckpt_path = hf_hub_download(
132
- repo_id=LTX_REPO_ID,
133
- filename=main_ckpt_filename,
134
- cache_dir=cache_dir
135
- )
136
  self.config["checkpoint_path"] = main_ckpt_path
137
  logging.info(f" -> Main checkpoint resolved to: {main_ckpt_path}")
138
 
139
- # Resolve o caminho do upsampler, se existir
140
  if self.config.get("spatial_upscaler_model_path"):
141
- upscaler_filename = self.config["spatial_upscaler_model_path"]
142
- upscaler_path = hf_hub_download(
143
- repo_id=LTX_REPO_ID,
144
- filename=upscaler_filename,
145
- cache_dir=cache_dir
146
- )
147
  self.config["spatial_upscaler_model_path"] = upscaler_path
148
  logging.info(f" -> Spatial upscaler resolved to: {upscaler_path}")
149
  except Exception as e:
@@ -201,7 +186,11 @@ class VideoService:
201
  num_chunks = len(prompt_list)
202
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
203
  frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
204
- overlap_frames = self.config.get("overlap_frames", 8) if is_narrative else 0
 
 
 
 
205
 
206
  temp_latent_paths = []
207
  overlap_condition_item = None
@@ -210,14 +199,16 @@ class VideoService:
210
  for i, chunk_prompt in enumerate(prompt_list):
211
  logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
212
 
213
- if i == num_chunks - 1:
214
- processed_frames = (num_chunks - 1) * frames_per_chunk
215
- current_frames = total_frames - processed_frames
216
- else:
217
- current_frames = frames_per_chunk
218
-
219
- if i > 0: current_frames += overlap_frames
220
 
 
 
 
 
221
  current_conditions = kwargs.get("initial_conditions", []) if i == 0 else []
222
  if overlap_condition_item: current_conditions.append(overlap_condition_item)
223
 
@@ -231,7 +222,8 @@ class VideoService:
231
  overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
232
  overlap_condition_item = ConditioningItem(media_item=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
233
 
234
- if i > 0: chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]
 
235
 
236
  chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
237
  torch.save(chunk_latents.cpu(), chunk_path)
@@ -359,12 +351,16 @@ class VideoService:
359
  else: self.runtime_autocast_dtype = torch.float32
360
  logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")
361
 
362
- def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT) -> int:
 
 
 
363
  return ((dim - 1) // alignment + 1) * alignment
364
 
365
  def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
366
  num_frames = int(round(duration_s * DEFAULT_FPS))
367
- aligned_frames = self._align(num_frames)
 
368
  return max(aligned_frames, min_frames)
369
 
370
  def _get_random_seed(self) -> int:
 
1
  # FILE: api/ltx_server_refactored_complete.py
2
+ # DESCRIPTION: Final orchestrator for LTX-Video generation.
3
+ # This version includes the fix for the narrative generation overlap bug and
4
+ # consolidates all previous refactoring and debugging improvements.
5
 
6
  import gc
7
  import json
 
36
  RESULTS_DIR = Path("/app/output")
37
  DEFAULT_FPS = 24.0
38
  FRAMES_ALIGNMENT = 8
39
+ LTX_REPO_ID = "Lightricks/LTX-Video"
40
 
41
  # Garante que a biblioteca LTX-Video seja importável
42
  def add_deps_to_path():
 
45
  sys.path.insert(0, repo_path)
46
  logging.info(f"[ltx_server] LTX-Video repository added to sys.path: {repo_path}")
47
 
48
+ add_deps_to_path()
49
 
50
  # --- Módulos da nossa Arquitetura ---
51
  try:
 
99
  logging.info(f"LTX allocated to devices: Main='{target_main_device_str}', VAE='{target_vae_device_str}'")
100
 
101
  self.config = self._load_config()
102
+ self._resolve_model_paths_from_cache()
103
 
104
  self.pipeline, self.latent_upsampler = build_ltx_pipeline_on_cpu(self.config)
105
 
 
119
  return yaml.safe_load(file)
120
 
121
  def _resolve_model_paths_from_cache(self):
122
+ """Finds the absolute paths to model files in the cache and updates the in-memory config."""
 
 
 
123
  logging.info("Resolving model paths from Hugging Face cache...")
124
  cache_dir = os.environ.get("HF_HOME")
125
  try:
126
+ main_ckpt_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=self.config["checkpoint_path"], cache_dir=cache_dir)
 
 
 
 
 
 
127
  self.config["checkpoint_path"] = main_ckpt_path
128
  logging.info(f" -> Main checkpoint resolved to: {main_ckpt_path}")
129
 
 
130
  if self.config.get("spatial_upscaler_model_path"):
131
+ upscaler_path = hf_hub_download(repo_id=LTX_REPO_ID, filename=self.config["spatial_upscaler_model_path"], cache_dir=cache_dir)
 
 
 
 
 
132
  self.config["spatial_upscaler_model_path"] = upscaler_path
133
  logging.info(f" -> Spatial upscaler resolved to: {upscaler_path}")
134
  except Exception as e:
 
186
  num_chunks = len(prompt_list)
187
  total_frames = self._calculate_aligned_frames(kwargs.get("duration", 4.0))
188
  frames_per_chunk = max(FRAMES_ALIGNMENT, (total_frames // num_chunks // FRAMES_ALIGNMENT) * FRAMES_ALIGNMENT)
189
+
190
+ # Overlap must be N*8+1 frames. 9 is the smallest practical value.
191
+ overlap_frames = 9 if is_narrative else 0
192
+ if is_narrative:
193
+ logging.info(f"Narrative mode: Using overlap of {overlap_frames} frames between chunks.")
194
 
195
  temp_latent_paths = []
196
  overlap_condition_item = None
 
199
  for i, chunk_prompt in enumerate(prompt_list):
200
  logging.info(f"Processing scene {i+1}/{num_chunks}: '{chunk_prompt[:50]}...'")
201
 
202
+ if i < num_chunks - 1:
203
+ current_frames_base = frames_per_chunk
204
+ else: # Last chunk takes all remaining frames
205
+ processed_frames_base = (num_chunks - 1) * frames_per_chunk
206
+ current_frames_base = total_frames - processed_frames_base
 
 
207
 
208
+ current_frames = current_frames_base + (overlap_frames if i > 0 else 0)
209
+ # Ensure final frame count for generation is N*8+1
210
+ current_frames = self._align(current_frames, alignment_rule='n*8+1')
211
+
212
  current_conditions = kwargs.get("initial_conditions", []) if i == 0 else []
213
  if overlap_condition_item: current_conditions.append(overlap_condition_item)
214
 
 
222
  overlap_latents = chunk_latents[:, :, -overlap_frames:, :, :].clone()
223
  overlap_condition_item = ConditioningItem(media_item=overlap_latents, media_frame_number=0, conditioning_strength=1.0)
224
 
225
+ if i > 0:
226
+ chunk_latents = chunk_latents[:, :, overlap_frames:, :, :]
227
 
228
  chunk_path = RESULTS_DIR / f"temp_chunk_{i}_{used_seed}.pt"
229
  torch.save(chunk_latents.cpu(), chunk_path)
 
351
  else: self.runtime_autocast_dtype = torch.float32
352
  logging.info(f"Runtime precision policy set for autocast: {self.runtime_autocast_dtype}")
353
 
354
+ def _align(self, dim: int, alignment: int = FRAMES_ALIGNMENT, alignment_rule: str = 'default') -> int:
355
+ """Aligns a dimension to the nearest multiple of `alignment`."""
356
+ if alignment_rule == 'n*8+1':
357
+ return ((dim - 1) // alignment) * alignment + 1
358
  return ((dim - 1) // alignment + 1) * alignment
359
 
360
  def _calculate_aligned_frames(self, duration_s: float, min_frames: int = 1) -> int:
361
  num_frames = int(round(duration_s * DEFAULT_FPS))
362
+ # Para a duração total, sempre arredondamos para cima para o múltiplo de 8 mais próximo
363
+ aligned_frames = self._align(num_frames, alignment=FRAMES_ALIGNMENT)
364
  return max(aligned_frames, min_frames)
365
 
366
  def _get_random_seed(self) -> int: