eeuuia committed on
Commit 61634bf · verified · 1 Parent(s): 3fa142d

Upload 2 files

Files changed (2)
  1. api/ltx_server_refactored.py +769 -0
  2. api/seedvr_server.py +277 -0
api/ltx_server_refactored.py ADDED
@@ -0,0 +1,769 @@
+ # ltx_server.py — VideoService (beta 1.1)
+ # Always output_type="latent"; at the end: VAE (whole block) → pixels → MP4.
+ # Ignores UserWarning/FutureWarning and injects the VAE into the manager with the correct dtype/device.
+ # --- 0. WARNINGS AND ENVIRONMENT ---
+ 
+ import warnings
+ warnings.filterwarnings("ignore", category=UserWarning)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ warnings.filterwarnings("ignore", message=".*")
+ from huggingface_hub import logging
+ # FIX: only the last set_verbosity_* call takes effect; keep a single level,
+ # consistent with the warning suppression above.
+ logging.set_verbosity_error()
+ 
+ import os, subprocess, shlex, tempfile
+ import torch
+ import json
+ import numpy as np
+ import random
+ import yaml
+ from typing import List, Dict
+ from pathlib import Path
+ import imageio
+ from PIL import Image  # import needed by handle_media_upload_for_dims
+ from huggingface_hub import hf_hub_download
+ import sys
+ import gc
+ import shutil
+ import contextlib
+ import time
+ import traceback
+ from einops import rearrange
+ import torch.nn.functional as F
+ from managers.vae_manager import vae_manager_singleton
+ from tools.video_encode_tool import video_encode_tool_singleton
+ 
+ # FIX: the original bare assignments (LTXV_DEBUG=1, LTXV_FRAME_LOG_EVERY=8) had no
+ # effect; expose them as defaults for the environment variables read in __init__.
+ os.environ.setdefault("LTXV_DEBUG", "1")
+ os.environ.setdefault("LTXV_FRAME_LOG_EVERY", "8")
+ 
+ DEPS_DIR = Path("/data")
+ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
+ 
+ # FIX: run_setup moved to the top of the file so it is defined before being called.
+ def run_setup():
+     setup_script_path = "setup.py"
+     if not os.path.exists(setup_script_path):
+         print("[DEBUG] 'setup.py' not found. Skipping dependency cloning.")
+         return
+     try:
+         print("[DEBUG] Running setup.py for dependencies...")
+         subprocess.run([sys.executable, setup_script_path], check=True)
+         print("[DEBUG] Setup finished successfully.")
+     except subprocess.CalledProcessError as e:
+         print(f"[DEBUG] ERROR in setup.py (code {e.returncode}). Aborting.")
+         sys.exit(1)
+ 
+ if not LTX_VIDEO_REPO_DIR.exists():
+     print(f"[DEBUG] Repository not found at {LTX_VIDEO_REPO_DIR}. Running setup...")
+     run_setup()
+ 
+ def add_deps_to_path():
+     repo_path = str(LTX_VIDEO_REPO_DIR.resolve())
+     if repo_path not in sys.path:
+         sys.path.insert(0, repo_path)
+         print(f"[DEBUG] Repo added to sys.path: {repo_path}")
+ def _query_gpu_processes_via_nvml(device_index: int) -> List[Dict]:
+     try:
+         import psutil
+         import pynvml as nvml
+         nvml.nvmlInit()
+         handle = nvml.nvmlDeviceGetHandleByIndex(device_index)
+         try:
+             procs = nvml.nvmlDeviceGetComputeRunningProcesses_v3(handle)
+         except Exception:
+             procs = nvml.nvmlDeviceGetComputeRunningProcesses(handle)
+         results = []
+         for p in procs:
+             pid = int(p.pid)
+             used_mb = None
+             try:
+                 if getattr(p, "usedGpuMemory", None) is not None and p.usedGpuMemory not in (0,):
+                     used_mb = max(0, int(p.usedGpuMemory) // (1024 * 1024))
+             except Exception:
+                 used_mb = None
+             name = "unknown"
+             user = "unknown"
+             try:
+                 pr = psutil.Process(pid)
+                 name = pr.name()
+                 user = pr.username()
+             except Exception:
+                 pass
+             results.append({"pid": pid, "name": name, "user": user, "used_mb": used_mb})
+         nvml.nvmlShutdown()
+         return results
+     except Exception:
+         return []
+ 
+ def _query_gpu_processes_via_nvidiasmi(device_index: int) -> List[Dict]:
+     cmd = f"nvidia-smi -i {device_index} --query-compute-apps=pid,process_name,used_memory --format=csv,noheader,nounits"
+     try:
+         out = subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT, text=True, timeout=2.0)
+     except Exception:
+         return []
+     results = []
+     for line in out.strip().splitlines():
+         parts = [p.strip() for p in line.split(",")]
+         if len(parts) >= 3:
+             try:
+                 pid = int(parts[0]); name = parts[1]; used_mb = int(parts[2])
+                 user = "unknown"
+                 try:
+                     import psutil
+                     pr = psutil.Process(pid)
+                     user = pr.username()
+                 except Exception:
+                     pass
+                 results.append({"pid": pid, "name": name, "user": user, "used_mb": used_mb})
+             except Exception:
+                 continue
+     return results
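+ 
+ # Illustrative note (not in the original file): with the flags used above,
+ # `nvidia-smi --query-compute-apps=... --format=csv,noheader,nounits` prints one
+ # CSV row per compute process, e.g. `12345, python, 10240`, which the parser
+ # turns into {"pid": 12345, "name": "python", "user": ..., "used_mb": 10240}.
+ 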
+ def calculate_new_dimensions(orig_w, orig_h, divisor=8):
+     if orig_w == 0 or orig_h == 0:
+         return 512, 512
+     if orig_w >= orig_h:
+         aspect_ratio = orig_w / orig_h
+         new_h = 512
+         new_w = new_h * aspect_ratio
+     else:
+         aspect_ratio = orig_h / orig_w
+         new_w = 512
+         new_h = new_w * aspect_ratio
+     final_w = int(round(new_w / divisor)) * divisor
+     final_h = int(round(new_h / divisor)) * divisor
+     final_w = max(divisor, final_w)
+     final_h = max(divisor, final_h)
+     print(f"[Dimension Calc] Original: {orig_w}x{orig_h} -> Computed: {new_w:.0f}x{new_h:.0f} -> Final (divisible by {divisor}): {final_w}x{final_h}")
+     return final_h, final_w
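+ 
+ # Worked example (illustrative, not in the original file): a 1920x1080 input is
+ # landscape, so the short side is pinned to 512: new_h=512, new_w=512*(1920/1080)≈910.2.
+ # Rounding each side to the nearest multiple of 8 yields 912x512, returned as (512, 912).
+ 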
+ def handle_media_upload_for_dims(filepath, current_h, current_w):
+     # FIX: Gradio (`gr`) must not be used in the backend. Returning a plain tuple instead.
+     if not filepath or not os.path.exists(str(filepath)):
+         return current_h, current_w
+     try:
+         if str(filepath).lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
+             with Image.open(filepath) as img:
+                 orig_w, orig_h = img.size
+         else:
+             with imageio.get_reader(filepath) as reader:
+                 meta = reader.get_meta_data()
+                 orig_w, orig_h = meta.get('size', (current_w, current_h))
+         new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
+         return new_h, new_w
+     except Exception as e:
+         print(f"Error while reading media dimensions: {e}")
+         return current_h, current_w
+ def _gpu_process_table(processes: List[Dict], current_pid: int) -> str:
+     if not processes:
+         return " - Active processes: (none)\n"
+     processes = sorted(processes, key=lambda x: (x.get("used_mb") or 0), reverse=True)
+     lines = [" - Active processes (PID | USER | NAME | VRAM MB):"]
+     for p in processes:
+         star = "*" if p["pid"] == current_pid else " "
+         used_str = str(p["used_mb"]) if p.get("used_mb") is not None else "N/A"
+         lines.append(f"   {star} {p['pid']} | {p['user']} | {p['name']} | {used_str}")
+     return "\n".join(lines) + "\n"
+ 
+ def log_tensor_info(tensor, name="Tensor"):
+     if not isinstance(tensor, torch.Tensor):
+         print(f"\n[INFO] '{name}' is not a tensor.")
+         return
+     print(f"\n--- Tensor: {name} ---")
+     print(f" - Shape: {tuple(tensor.shape)}")
+     print(f" - Dtype: {tensor.dtype}")
+     print(f" - Device: {tensor.device}")
+     if tensor.numel() > 0:
+         try:
+             print(f" - Min: {tensor.min().item():.4f} Max: {tensor.max().item():.4f} Mean: {tensor.mean().item():.4f}")
+         except Exception:
+             pass
+     print("------------------------------------------\n")
+ add_deps_to_path()
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+ from ltx_video.models.autoencoders.vae_encode import un_normalize_latents, normalize_latents
+ from ltx_video.pipelines.pipeline_ltx_video import adain_filter_latent
+ from api.ltx.inference import (
+     create_ltx_video_pipeline,
+     create_latent_upsampler,
+     load_image_to_tensor_with_resize_and_crop,
+     seed_everething,
+     calculate_padding,
+     load_media_file,
+ )
+ class VideoService:
+     def __init__(self):
+         t0 = time.perf_counter()
+         print("[DEBUG] Initializing VideoService...")
+         self.debug = os.getenv("LTXV_DEBUG", "1") == "1"
+         self.frame_log_every = int(os.getenv("LTXV_FRAME_LOG_EVERY", "8"))
+         self.config = self._load_config()
+         print(f"[DEBUG] Config loaded (precision={self.config.get('precision')}, sampler={self.config.get('sampler')})")
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         print(f"[DEBUG] Selected device: {self.device}")
+         self.last_memory_reserved_mb = 0.0
+         self._tmp_dirs = set(); self._tmp_files = set(); self._last_outputs = []
+ 
+         self.pipeline, self.latent_upsampler = self._load_models()
+         print(f"[DEBUG] Pipeline and upsampler loaded. Upsampler active? {bool(self.latent_upsampler)}")
+ 
+         print(f"[DEBUG] Moving models to {self.device}...")
+         self.pipeline.to(self.device)
+         if self.latent_upsampler:
+             self.latent_upsampler.to(self.device)
+ 
+         self._apply_precision_policy()
+         print(f"[DEBUG] runtime_autocast_dtype = {getattr(self, 'runtime_autocast_dtype', None)}")
+ 
+         vae_manager_singleton.attach_pipeline(
+             self.pipeline,
+             device=self.device,
+             autocast_dtype=self.runtime_autocast_dtype
+         )
+         print(f"[DEBUG] VAE manager attached: has_vae={hasattr(self.pipeline, 'vae')} device={self.device}")
+ 
+         if self.device == "cuda":
+             torch.cuda.empty_cache()
+             self._log_gpu_memory("After loading models")
+ 
+         print(f"[DEBUG] VideoService ready. boot_time={time.perf_counter()-t0:.3f}s")
+ 
+     def _log_gpu_memory(self, stage_name: str):
+         if self.device != "cuda":
+             return
+         device_index = torch.cuda.current_device() if torch.cuda.is_available() else 0
+         current_reserved_b = torch.cuda.memory_reserved(device_index)
+         current_reserved_mb = current_reserved_b / (1024 ** 2)
+         total_memory_b = torch.cuda.get_device_properties(device_index).total_memory
+         total_memory_mb = total_memory_b / (1024 ** 2)
+         peak_reserved_mb = torch.cuda.max_memory_reserved(device_index) / (1024 ** 2)
+         delta_mb = current_reserved_mb - getattr(self, "last_memory_reserved_mb", 0.0)
+         processes = _query_gpu_processes_via_nvml(device_index) or _query_gpu_processes_via_nvidiasmi(device_index)
+         print(f"\n--- [GPU LOG] {stage_name} (cuda:{device_index}) ---")
+         print(f" - Reserved: {current_reserved_mb:.2f} MB / {total_memory_mb:.2f} MB (Δ={delta_mb:+.2f} MB)")
+         if peak_reserved_mb > getattr(self, "last_memory_reserved_mb", 0.0):
+             print(f" - Peak reserved (this stage): {peak_reserved_mb:.2f} MB")
+         print(_gpu_process_table(processes, os.getpid()), end="")
+         print("--------------------------------------------------\n")
+         self.last_memory_reserved_mb = current_reserved_mb
+ 
+     def _register_tmp_dir(self, d: str):
+         if d and os.path.isdir(d):
+             self._tmp_dirs.add(d); print(f"[DEBUG] Registered tmp dir: {d}")
+ 
+     def _register_tmp_file(self, f: str):
+         if f and os.path.exists(f):
+             self._tmp_files.add(f); print(f"[DEBUG] Registered tmp file: {f}")
+ 
+     def finalize(self, keep_paths=None, extra_paths=None, clear_gpu=True):
+         print("[DEBUG] Finalize: starting cleanup...")
+         keep = set(keep_paths or []); extras = set(extra_paths or [])
+         removed_files = 0
+         for f in list(self._tmp_files | extras):
+             try:
+                 if f not in keep and os.path.isfile(f):
+                     os.remove(f); removed_files += 1; print(f"[DEBUG] Removed tmp file: {f}")
+             except Exception as e:
+                 print(f"[DEBUG] Failed to remove file {f}: {e}")
+             finally:
+                 self._tmp_files.discard(f)
+         removed_dirs = 0
+         for d in list(self._tmp_dirs):
+             try:
+                 if d not in keep and os.path.isdir(d):
+                     shutil.rmtree(d, ignore_errors=True); removed_dirs += 1; print(f"[DEBUG] Removed tmp dir: {d}")
+             except Exception as e:
+                 print(f"[DEBUG] Failed to remove dir {d}: {e}")
+             finally:
+                 self._tmp_dirs.discard(d)
+         print(f"[DEBUG] Finalize: files removed={removed_files}, dirs removed={removed_dirs}")
+         gc.collect()
+         try:
+             if clear_gpu and torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 try:
+                     torch.cuda.ipc_collect()
+                 except Exception:
+                     pass
+         except Exception as e:
+             print(f"[DEBUG] Finalize: GPU cleanup failed: {e}")
+         try:
+             self._log_gpu_memory("After finalize")
+         except Exception as e:
+             print(f"[DEBUG] Post-finalize GPU log failed: {e}")
+ 
+     def _load_config(self):
+         base = LTX_VIDEO_REPO_DIR / "configs"
+         candidates = [
+             base / "ltxv-13b-0.9.8-dev-fp8.yaml",
+             base / "ltxv-13b-0.9.8-distilled-fp8.yaml",
+             base / "ltxv-13b-0.9.8-distilled.yaml",
+         ]
+         for cfg in candidates:
+             if cfg.exists():
+                 print(f"[DEBUG] Selected config: {cfg}")
+                 with open(cfg, "r") as file:
+                     return yaml.safe_load(file)
+         cfg = base / "ltxv-13b-0.9.8-distilled-fp8.yaml"
+         print(f"[DEBUG] Fallback config: {cfg}")
+         with open(cfg, "r") as file:
+             return yaml.safe_load(file)
+ 
+     def _load_models(self):
+         t0 = time.perf_counter()
+         LTX_REPO = "Lightricks/LTX-Video"
+         print("[DEBUG] Downloading main checkpoint...")
+         distilled_model_path = hf_hub_download(
+             repo_id=LTX_REPO,
+             filename=self.config["checkpoint_path"],
+             local_dir=os.getenv("HF_HOME"),
+             cache_dir=os.getenv("HF_HOME_CACHE"),
+             token=os.getenv("HF_TOKEN"),
+         )
+         self.config["checkpoint_path"] = distilled_model_path
+         print(f"[DEBUG] Checkpoint at: {distilled_model_path}")
+ 
+         print("[DEBUG] Downloading spatial upscaler...")
+         spatial_upscaler_path = hf_hub_download(
+             repo_id=LTX_REPO,
+             filename=self.config["spatial_upscaler_model_path"],
+             local_dir=os.getenv("HF_HOME"),
+             cache_dir=os.getenv("HF_HOME_CACHE"),
+             token=os.getenv("HF_TOKEN")
+         )
+         self.config["spatial_upscaler_model_path"] = spatial_upscaler_path
+         print(f"[DEBUG] Upscaler at: {spatial_upscaler_path}")
+ 
+         print("[DEBUG] Building pipeline...")
+         pipeline = create_ltx_video_pipeline(
+             ckpt_path=self.config["checkpoint_path"],
+             precision=self.config["precision"],
+             text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
+             sampler=self.config["sampler"],
+             device="cpu",
+             enhance_prompt=False,
+             prompt_enhancer_image_caption_model_name_or_path=self.config["prompt_enhancer_image_caption_model_name_or_path"],
+             prompt_enhancer_llm_model_name_or_path=self.config["prompt_enhancer_llm_model_name_or_path"],
+         )
+         print("[DEBUG] Pipeline ready.")
+ 
+         latent_upsampler = None
+         if self.config.get("spatial_upscaler_model_path"):
+             print("[DEBUG] Building latent_upsampler...")
+             latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
+             print("[DEBUG] Upsampler ready.")
+         print(f"[DEBUG] _load_models() total time={time.perf_counter()-t0:.3f}s")
+         return pipeline, latent_upsampler
+ 
+     def _promote_fp8_weights_to_bf16(self, module):
+         if not isinstance(module, torch.nn.Module):
+             print("[DEBUG] FP8→BF16 promotion skipped: target is not an nn.Module.")
+             return
+         f8 = getattr(torch, "float8_e4m3fn", None)
+         if f8 is None:
+             print("[DEBUG] torch.float8_e4m3fn unavailable.")
+             return
+         p_cnt = b_cnt = 0
+         for _, p in module.named_parameters(recurse=True):
+             try:
+                 if p.dtype == f8:
+                     with torch.no_grad():
+                         p.data = p.data.to(torch.bfloat16); p_cnt += 1
+             except Exception:
+                 pass
+         for _, b in module.named_buffers(recurse=True):
+             try:
+                 if hasattr(b, "dtype") and b.dtype == f8:
+                     b.data = b.data.to(torch.bfloat16); b_cnt += 1
+             except Exception:
+                 pass
+         print(f"[DEBUG] FP8→BF16: params_promoted={p_cnt}, buffers_promoted={b_cnt}")
+ 
+     @torch.no_grad()
+     def _upsample_latents_internal(self, latents: torch.Tensor) -> torch.Tensor:
+         if not self.latent_upsampler:
+             raise ValueError("Latent upsampler is not loaded.")
+         self.latent_upsampler.to(self.device)
+         self.pipeline.vae.to(self.device)
+         print(f"[DEBUG-UPSAMPLE] Input shape: {tuple(latents.shape)}")
+         latents = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
+         upsampled_latents = self.latent_upsampler(latents)
+         upsampled_latents = normalize_latents(upsampled_latents, self.pipeline.vae, vae_per_channel_normalize=True)
+         print(f"[DEBUG-UPSAMPLE] Output shape: {tuple(upsampled_latents.shape)}")
+         return upsampled_latents
+ 
+     def _apply_precision_policy(self):
+         prec = str(self.config.get("precision", "")).lower()
+         self.runtime_autocast_dtype = torch.float32
+         print(f"[DEBUG] Applying precision policy: {prec}")
+         if prec == "float8_e4m3fn":
+             self.runtime_autocast_dtype = torch.bfloat16
+             force_promote = os.getenv("LTXV_FORCE_BF16_ON_FP8", "0") == "1"
+             print(f"[DEBUG] FP8 detected. force_promote={force_promote}")
+             if force_promote and hasattr(torch, "float8_e4m3fn"):
+                 try:
+                     self._promote_fp8_weights_to_bf16(self.pipeline)
+                 except Exception as e:
+                     print(f"[DEBUG] FP8→BF16 promotion on the pipeline failed: {e}")
+                 try:
+                     if self.latent_upsampler:
+                         self._promote_fp8_weights_to_bf16(self.latent_upsampler)
+                 except Exception as e:
+                     print(f"[DEBUG] FP8→BF16 promotion on the upsampler failed: {e}")
+         elif prec == "bfloat16":
+             self.runtime_autocast_dtype = torch.bfloat16
+         elif prec == "mixed_precision":
+             self.runtime_autocast_dtype = torch.float16
+         else:
+             self.runtime_autocast_dtype = torch.float32
+ 
+     def _prepare_conditioning_tensor(self, filepath, height, width, padding_values):
+         print(f"[DEBUG] Loading conditioning media: {filepath}")
+         tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
+         tensor = torch.nn.functional.pad(tensor, padding_values)
+         out = tensor.to(self.device, dtype=self.runtime_autocast_dtype) if self.device == "cuda" else tensor.to(self.device)
+         print(f"[DEBUG] Cond shape={tuple(out.shape)} dtype={out.dtype} device={out.device}")
+         return out
+ 
+     def _dividir_latentes_por_tamanho(self, latents_brutos, num_latente_por_chunk: int, overlap: int = 1):
+         # Splits the latent tensor along the temporal axis (dim=2) into chunks of
+         # `num_latente_por_chunk` latents; every chunk except the last carries
+         # `overlap` extra latents at its end, so neighbours share frames for the
+         # crossfade performed later.
+         sum_latent = latents_brutos.shape[2]
+         if num_latente_por_chunk >= sum_latent:
+             return [latents_brutos.clone().detach()]  # FIX: return a list, and clone
+         # FIX: chunking logic simplified and corrected to avoid index overflow
+         chunks = []
+         start = 0
+         while start < sum_latent:
+             end = min(start + num_latente_por_chunk, sum_latent)
+             # Extend the chunk by `overlap` latents unless it already reaches the end
+             if end < sum_latent:
+                 end = min(end + overlap, sum_latent)
+             chunk = latents_brutos[:, :, start:end, :, :].clone().detach()
+             print(f"[DEBUG] Chunk: start={start}, latents={chunk.shape[2]}, total_latents={sum_latent}")
+             chunks.append(chunk)
+             # Advance to the next chunk
+             if start + num_latente_por_chunk >= sum_latent:
+                 break
+             start += num_latente_por_chunk
+         return chunks
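+ 
+     # Illustrative trace (not in the original file): with 80 temporal latents,
+     # num_latente_por_chunk=32 and overlap=8 (the values generate() uses below),
+     # the loop yields slices [0:40], [32:72] and [64:80]; each non-final chunk
+     # carries 8 extra latents so neighbouring clips overlap for the crossfade.
+ 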
+     def _get_total_frames(self, video_path: str) -> int:
+         cmd = [
+             "ffprobe", "-v", "error", "-select_streams", "v:0", "-count_frames",
+             "-show_entries", "stream=nb_read_frames", "-of", "default=nokey=1:noprint_wrappers=1", video_path
+         ]
+         result = subprocess.run(cmd, capture_output=True, text=True, check=True)
+         return int(result.stdout.strip())
+ 
+     def _gerar_lista_com_transicoes(self, pasta: str, video_paths: list[str], crossfade_frames: int = 8) -> list[str]:
+         # Builds a playlist of trimmed clips with short crossfade clips between them.
+         # This is fragile with odd filenames; the logic stays close to the original,
+         # avoiding `shell=True` and keeping the pruned list aligned with `video_paths`.
+         if len(video_paths) <= 1:
+             return video_paths  # Nothing to do
+ 
+         # First, create all the trimmed ("pruned") videos
+         videos_podados = []
+         for i, base in enumerate(video_paths):
+             video_podado = os.path.join(pasta, f"podado_{i}.mp4")
+             total_frames = self._get_total_frames(base)
+ 
+             start_frame = crossfade_frames if i > 0 else 0
+             end_frame = total_frames - crossfade_frames if i < len(video_paths) - 1 else total_frames
+ 
+             # FIX: if there are not enough frames to trim, keep the original clip so
+             # the indices of `videos_podados` stay aligned with `video_paths`.
+             if start_frame >= end_frame:
+                 shutil.copy(base, video_podado)
+                 videos_podados.append(video_podado)
+                 continue
+ 
+             cmd = [
+                 'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error', '-i', base,
+                 '-vf', f'trim=start_frame={start_frame}:end_frame={end_frame},setpts=PTS-STARTPTS',
+                 '-an', video_podado
+             ]
+             subprocess.run(cmd, check=True)
+             videos_podados.append(video_podado)
+ 
+         # Now create the transitions and assemble the final list
+         lista_final = [videos_podados[0]]
+         for i in range(len(video_paths) - 1):
+             video_anterior = video_paths[i]
+             video_seguinte = video_paths[i+1]
+ 
+             # Extract the tail fade from the previous clip
+             fade_fim_path = os.path.join(pasta, f"fade_fim_{i}.mp4")
+             total_frames_anterior = self._get_total_frames(video_anterior)
+             cmd_fim = [
+                 'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error', '-i', video_anterior,
+                 '-vf', f'trim=start_frame={total_frames_anterior - crossfade_frames},setpts=PTS-STARTPTS',
+                 '-an', fade_fim_path
+             ]
+             subprocess.run(cmd_fim, check=True)
+ 
+             # Extract the head fade from the next clip
+             fade_ini_path = os.path.join(pasta, f"fade_ini_{i+1}.mp4")
+             cmd_ini = [
+                 'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error', '-i', video_seguinte,
+                 '-vf', f'trim=end_frame={crossfade_frames},setpts=PTS-STARTPTS', '-an', fade_ini_path
+             ]
+             subprocess.run(cmd_ini, check=True)
+ 
+             # Create the crossfade clip. FIX: blend's `T` variable is time in seconds,
+             # so the original `T/{crossfade_frames}` only ramped correctly at 1 fps;
+             # `N` (frame index) gives a frame-accurate linear ramp.
+             transicao_path = os.path.join(pasta, f"transicao_{i}_{i+1}.mp4")
+             cmd_blend = [
+                 'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error',
+                 '-i', fade_fim_path, '-i', fade_ini_path,
+                 '-filter_complex', f"[0:v][1:v]blend=all_expr='A*(1-N/{crossfade_frames})+B*(N/{crossfade_frames})',format=yuv420p",
+                 '-frames:v', str(crossfade_frames), transicao_path
+             ]
+             subprocess.run(cmd_blend, check=True)
+ 
+             lista_final.append(transicao_path)
+             lista_final.append(videos_podados[i+1])
+ 
+         return lista_final
+ 
+     def _concat_mp4s_no_reencode(self, mp4_list: List[str], out_path: str):
+         if not mp4_list:
+             raise ValueError("The list of MP4s to concatenate is empty.")
+         # With a single video there is nothing to concatenate; just move it
+         if len(mp4_list) == 1:
+             shutil.move(mp4_list[0], out_path)
+             print(f"[DEBUG] Single video, moved to: {out_path}")
+             return
+ 
+         with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt") as f:
+             for mp4 in mp4_list:
+                 f.write(f"file '{os.path.abspath(mp4)}'\n")
+             list_path = f.name
+ 
+         cmd = f"ffmpeg -y -f concat -safe 0 -i {list_path} -c copy {out_path}"
+         print(f"[DEBUG] Concat: {cmd}")
+ 
+         try:
+             subprocess.check_call(shlex.split(cmd))
+         finally:
+             try:
+                 os.remove(list_path)
+             except Exception:
+                 pass
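+ 
+     # Illustrative note (not in the original file): the concat-demuxer list file
+     # written above holds one absolute path per line, e.g.
+     #     file '/tmp/ltxv_xxx/podado_0.mp4'
+     #     file '/tmp/ltxv_xxx/transicao_0_1.mp4'
+     #     file '/tmp/ltxv_xxx/podado_1.mp4'
+     # `-c copy` stays lossless only because every part was encoded with identical
+     # codec settings; mixed parameters would require re-encoding.
+ 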
+     def generate(
+         self,
+         prompt,
+         negative_prompt,
+         mode="text-to-video",
+         start_image_filepath=None,
+         middle_image_filepath=None,
+         middle_frame_number=None,
+         middle_image_weight=1.0,
+         end_image_filepath=None,
+         end_image_weight=1.0,
+         input_video_filepath=None,
+         height=512,
+         width=704,
+         duration=2.0,
+         frames_to_use=9,  # Unused parameter, kept for API consistency
+         seed=42,
+         randomize_seed=True,
+         guidance_scale=3.0,
+         improve_texture=True,
+         progress_callback=None,
+         external_decode=True,  # Unused parameter, kept for API consistency
+     ):
+         t_all = time.perf_counter()
+         print(f"[DEBUG] generate() begin mode={mode} improve_texture={improve_texture}")
+         if self.device == "cuda":
+             torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
+         self._log_gpu_memory("Start of generation")
+ 
+         if mode == "image-to-video" and not start_image_filepath:
+             raise ValueError("A start image is required for image-to-video mode")
+         used_seed = random.randint(0, 2**32 - 1) if randomize_seed else int(seed)
+         seed_everething(used_seed); print(f"[DEBUG] Seed used: {used_seed}")
+         FPS = 24.0; MAX_NUM_FRAMES = 2570
+         # Snap the requested duration to the 8n+1 frame counts the model expects
+         target_frames_rounded = round(duration * FPS)
+         n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
+         actual_num_frames = max(9, min(MAX_NUM_FRAMES, int(n_val * 8 + 1)))
+         height_padded = ((height - 1) // 8 + 1) * 8
+         width_padded = ((width - 1) // 8 + 1) * 8
+         padding_values = calculate_padding(height, width, height_padded, width_padded)
+         generator = torch.Generator(device=self.device).manual_seed(used_seed)
+ 
+         conditioning_items = []
+         if mode == "image-to-video":
+             start_tensor = self._prepare_conditioning_tensor(start_image_filepath, height, width, padding_values)
+             conditioning_items.append(ConditioningItem(start_tensor, 0, 1.0))
+             if middle_image_filepath and middle_frame_number is not None:
+                 middle_tensor = self._prepare_conditioning_tensor(middle_image_filepath, height, width, padding_values)
+                 safe_middle_frame = max(0, min(int(middle_frame_number), actual_num_frames - 1))
+                 conditioning_items.append(ConditioningItem(middle_tensor, safe_middle_frame, float(middle_image_weight)))
+             if end_image_filepath:
+                 end_tensor = self._prepare_conditioning_tensor(end_image_filepath, height, width, padding_values)
+                 last_frame_index = actual_num_frames - 1
+                 conditioning_items.append(ConditioningItem(end_tensor, last_frame_index, float(end_image_weight)))
+         print(f"[DEBUG] Conditioning items: {len(conditioning_items)}")
+ 
+         call_kwargs = {
+             "prompt": prompt, "negative_prompt": negative_prompt, "height": height_padded, "width": width_padded,
+             "num_frames": actual_num_frames, "frame_rate": int(FPS), "generator": generator, "output_type": "latent",
+             "conditioning_items": conditioning_items if conditioning_items else None, "media_items": None,
+             "decode_timestep": self.config["decode_timestep"], "decode_noise_scale": self.config["decode_noise_scale"],
+             "stochastic_sampling": self.config["stochastic_sampling"], "image_cond_noise_scale": 0.01, "is_video": True,
+             "vae_per_channel_normalize": True, "mixed_precision": (self.config["precision"] == "mixed_precision"),
+             "offload_to_cpu": False, "enhance_prompt": False, "skip_layer_strategy": SkipLayerStrategy.AttentionValues,
+         }
+ 
+         # FIX: initialize the accumulators before the two generation branches
+         latents_list = []
+         temp_dir = tempfile.mkdtemp(prefix="ltxv_"); self._register_tmp_dir(temp_dir)
+         results_dir = "/app/output"; os.makedirs(results_dir, exist_ok=True)
+ 
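+         # Worked example for the sizing above (illustrative, not in the original):
+         # duration=2.0 s at FPS=24 targets round(2.0*24)=48 frames; n_val=round(47/8)=6,
+         # so actual_num_frames=6*8+1=49 (~2.04 s), the nearest valid 8n+1 frame count.
+         # Likewise, height=512 and width=704 are already multiples of 8, so no padding is added.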
+         try:
+             if improve_texture:
+                 ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
+                 with ctx:
+                     if not self.latent_upsampler:
+                         raise ValueError("Spatial upscaler not loaded, but 'improve_texture' is enabled.")
+ 
+                     print("\n--- STARTING STAGE 1: BASE GENERATION (FIRST PASS) ---")
+                     t_pass1 = time.perf_counter()
+                     first_pass_config = self.config.get("first_pass", {}).copy()
+                     first_pass_config.pop("num_inference_steps", None)
+                     downscale_factor = self.config.get("downscale_factor", 0.6666666)
+                     vae_scale_factor = self.pipeline.vae_scale_factor
+                     x_width = int(width_padded * downscale_factor)
+                     downscaled_width = x_width - (x_width % vae_scale_factor)
+                     x_height = int(height_padded * downscale_factor)
+                     downscaled_height = x_height - (x_height % vae_scale_factor)
+                     print(f"[DEBUG] First pass dims: padded ({width_padded}x{height_padded}) -> downscaled ({downscaled_width}x{downscaled_height})")
+ 
+                     first_pass_kwargs = call_kwargs.copy()
+                     first_pass_kwargs.update({
+                         "output_type": "latent", "width": downscaled_width, "height": downscaled_height,
+                         "guidance_scale": float(guidance_scale), **first_pass_config
+                     })
+ 
+                     print(f"[DEBUG] First pass: generating at {downscaled_width}x{downscaled_height}...")
+                     # FIX: use self.pipeline, not the deleted local variable 'pipeline'
+                     latents = self.pipeline(**first_pass_kwargs).images
+                     log_tensor_info(latents, "Base latents (first pass)")
+                     print(f"[DEBUG] First pass finished in {time.perf_counter() - t_pass1:.2f}s")
+ 
+                 with ctx:
+                     print("\n--- STARTING STAGE 2: LATENT UPSCALE ---")
+                     t_upscale = time.perf_counter()
+                     upsampled_latents = self._upsample_latents_internal(latents)
+                     upsampled_latents = adain_filter_latent(latents=upsampled_latents, reference_latents=latents)
+                     print(f"[DEBUG] Latent upscale finished in {time.perf_counter() - t_upscale:.2f}s")
+ 
+                 # FIX: keep the original latents as an AdaIN reference and hand the
+                 # upscaled latents to the second pass. (Currently unused, since the
+                 # per-chunk AdaIN step below is skipped.)
+                 reference_latents_cpu = latents.detach().to("cpu", non_blocking=True)
+                 latents_to_refine = upsampled_latents
+                 del upsampled_latents; del latents; gc.collect(); torch.cuda.empty_cache()
+ 
+                 # FIX: chunked second pass, e.g. chunks of 32 latents with 8 of overlap
+                 latents_parts = self._dividir_latentes_por_tamanho(latents_to_refine, 32, 8)
+                 del latents_to_refine
+ 
+                 with ctx:
+                     for i, latents_chunk in enumerate(latents_parts):
+                         print(f"\n--- STARTING STAGE 3.{i+1}: TEXTURE REFINEMENT (SECOND PASS) ---")
+                         # NOTE: AdaIN would need reference latents with the same H/W,
+                         # which is not the case after upscaling; it could be applied per
+                         # chunk for normalization, but is skipped here for simplicity.
+ 
+                         second_pass_config = self.config.get("second_pass", {}).copy()
+                         second_pass_config.pop("num_inference_steps", None)
+ 
+                         # The second pass runs at the size of the (upscaled) input latents
+                         second_pass_height, second_pass_width = latents_chunk.shape[3] * 8, latents_chunk.shape[4] * 8
+ 
+                         print(f"[DEBUG] Second pass dims: target ({second_pass_width}x{second_pass_height})")
+                         t_pass2 = time.perf_counter()
+                         second_pass_kwargs = call_kwargs.copy()
+                         second_pass_kwargs.update({
+                             "output_type": "latent", "width": second_pass_width, "height": second_pass_height,
+                             "latents": latents_chunk.to(self.device),  # move the chunk to the GPU
+                             "guidance_scale": float(guidance_scale),
+                             "num_frames": latents_chunk.shape[2],  # use the chunk's own frame count
+                             **second_pass_config
+                         })
+                         print(f"[DEBUG] Second pass: refining chunk {i+1}/{len(latents_parts)}...")
+                         final_latents = self.pipeline(**second_pass_kwargs).images
+                         log_tensor_info(final_latents, "Final latents (post second pass)")
+                         print(f"[DEBUG] Second pass chunk finished in {time.perf_counter() - t_pass2:.2f}s")
+                         latents_cpu = final_latents.detach().to("cpu", non_blocking=True)
+                         latents_list.append(latents_cpu)
+                         del final_latents; del latents_chunk; gc.collect(); torch.cuda.empty_cache()
+             else:
+                 ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype) if self.device == "cuda" else contextlib.nullcontext()
+                 with ctx:
+                     print("\n--- STARTING SINGLE-STAGE GENERATION ---")
+                     t_single = time.perf_counter()
+                     single_pass_call_kwargs = call_kwargs.copy()
+                     # FIX: `pipeline_instance` does not exist; use `self.pipeline`.
+                     latents_single_pass = self.pipeline(**single_pass_call_kwargs).images
+                     log_tensor_info(latents_single_pass, "Final latents (single stage)")
+                     print(f"[DEBUG] Single stage finished in {time.perf_counter() - t_single:.2f}s")
+                     latents_cpu = latents_single_pass.detach().to("cpu", non_blocking=True)
+                     latents_list.append(latents_cpu)  # FIX: append latents_cpu, not latents_single_pass
+                     del latents_single_pass; gc.collect(); torch.cuda.empty_cache()
+ 
+             # --- FINAL STAGE: DECODING AND MP4 ENCODING ---
+             print("\n--- STARTING FINAL STAGE: DECODING AND ASSEMBLY ---")
+             partes_mp4 = []
+             for i, latents in enumerate(latents_list):
+                 print(f"[DEBUG] Decoding partition {i+1}/{len(latents_list)}: {tuple(latents.shape)}")
+                 output_video_path = os.path.join(temp_dir, f"output_{used_seed}_{i}.mp4")
+ 
+                 pixel_tensor = vae_manager_singleton.decode(
+                     latents.to(self.device, non_blocking=True),
+                     decode_timestep=float(self.config.get("decode_timestep", 0.05))
+                 )
+                 log_tensor_info(pixel_tensor, "Pixel tensor (VAE output)")
+ 
+                 video_encode_tool_singleton.save_video_from_tensor(
+                     pixel_tensor, output_video_path, fps=call_kwargs["frame_rate"], progress_callback=progress_callback
+                 )
+                 partes_mp4.append(output_video_path)
+                 del pixel_tensor; del latents; gc.collect(); torch.cuda.empty_cache()
+ 
+             final_vid = os.path.join(results_dir, f"final_video_{used_seed}.mp4")
+             if len(partes_mp4) > 1:
+                 # _gerar_lista_com_transicoes is fragile; a plain concat is the robust
+                 # fallback. For the crossfades to line up, the overlap used when
+                 # splitting the latents must match crossfade_frames exactly.
+                 print("[DEBUG] Multiple parts generated, concatenating...")
+                 partes_mp4_fade = self._gerar_lista_com_transicoes(pasta=temp_dir, video_paths=partes_mp4, crossfade_frames=8)
+                 self._concat_mp4s_no_reencode(partes_mp4_fade, final_vid)
+             else:
+                 shutil.move(partes_mp4[0], final_vid)
+ 
+             self._log_gpu_memory("End of generation")
+             return final_vid, used_seed
+ 
+         except Exception as e:
+             print("[DEBUG] EXCEPTION DURING GENERATION:")
+             print("".join(traceback.format_exception(type(e), e, e.__traceback__)))
+             raise
+ 
+         finally:
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+                 torch.cuda.ipc_collect()
+             self.finalize(keep_paths=[])  # the final result has already been moved
+ 
+ print("Creating the VideoService instance. Model loading will start now...")
+ video_generation_service = VideoService()
api/seedvr_server.py ADDED
@@ -0,0 +1,277 @@
+ # api/seedvr_server.py
+ 
+ import os
+ import sys
+ import time
+ import subprocess
+ import queue
+ import multiprocessing as mp
+ from pathlib import Path
+ from typing import Optional, Callable
+ 
+ from huggingface_hub import hf_hub_download
+ 
+ # -------------------------------------------------------------
+ # 1. ENVIRONMENT AND CUDA CONFIGURATION
+ # -------------------------------------------------------------
+ 
+ # Ensures CUDA is used safely with multiprocessing, for stability.
+ if mp.get_start_method(allow_none=True) != 'spawn':
+     mp.set_start_method('spawn', force=True)
+ 
+ # VRAM allocator configuration
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
+ 
+ # Dynamically adds the cloned repository path to sys.path.
+ SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
+ if str(SEEDVR_REPO_PATH) not in sys.path:
+     sys.path.insert(0, str(SEEDVR_REPO_PATH))
+ 
+ # Heavy imports (torch, etc.) happen after the environment is configured.
+ import torch
+ import cv2
+ import numpy as np
+ from datetime import datetime
+ 
+ # -------------------------------------------------------------
+ # 2. PROCESSING HELPERS (Workers and I/O)
+ # -------------------------------------------------------------
+ 
+ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
+     """Extracts frames from a video and converts them to tensor format."""
+     if debug: print(f"🎬 Extracting frames from: {video_path}")
+     if not os.path.exists(video_path): raise FileNotFoundError(f"Video file not found: {video_path}")
+ 
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened(): raise ValueError(f"Could not open the video file: {video_path}")
+ 
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     if debug: print(f"📊 Video info: {frame_count} frames, {fps:.2f} FPS")
+ 
+     frames = []
+     frames_loaded = 0
+     for i in range(frame_count):
+         ret, frame = cap.read()
+         if not ret: break
+         if i < skip_first_frames: continue
+         if load_cap and frames_loaded >= load_cap: break
+ 
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         frames.append(frame.astype(np.float32) / 255.0)
+         frames_loaded += 1
+     cap.release()
+ 
+     if not frames: raise ValueError(f"No frames were extracted from the video: {video_path}")
+     if debug: print(f"✅ {len(frames)} frames extracted successfully.")
+     return torch.from_numpy(np.stack(frames)).to(torch.float16), fps
+ 
+ def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
+     """Saves a tensor of frames to a video file."""
+     if debug: print(f"🎬 Saving {frames_tensor.shape[0]} frames to: {output_path}")
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ 
+     frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
+     T, H, W, _ = frames_np.shape
+ 
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
+     if not out.isOpened(): raise ValueError(f"Could not create the video file: {output_path}")
+ 
+     for frame in frames_np:
+         out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
+     out.release()
+     if debug: print(f"✅ Video saved successfully: {output_path}")
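+ 
+ # Illustrative note (not in the original file): both helpers use (T, H, W, C)
+ # float tensors in [0, 1]. A 100-frame 720p clip is a (100, 720, 1280, 3)
+ # float16 tensor (~0.5 GB), which is why run_inference() below splits it across
+ # GPUs with torch.chunk along dim=0, the time axis.
+ 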
+ def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
+     """Child (worker) process that runs the upscaling on a dedicated GPU."""
+     os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
+     os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
+ 
+     import torch
+     from src.core.model_manager import configure_runner
+     from src.core.generation import generation_loop
+ 
+     try:
+         frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
+ 
+         callback = (lambda b, t, _, m: progress_queue.put((proc_idx, b, t, m))) if progress_queue else None
+ 
+         runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
+         result_tensor = generation_loop(
+             runner=runner, images=frames_tensor, cfg_scale=1.0, seed=shared_args["seed"],
+             res_w=shared_args["resolution"], batch_size=shared_args["batch_size"],
+             preserve_vram=shared_args["preserve_vram"], temporal_overlap=0,
+             debug=shared_args["debug"], progress_callback=callback
+         )
+         return_queue.put((proc_idx, result_tensor.cpu().numpy()))
+     except Exception as e:
+         import traceback
+         error_msg = f"ERROR in worker {proc_idx}: {e}\n{traceback.format_exc()}"
+         print(error_msg)
+         if progress_queue: progress_queue.put((proc_idx, -1, -1, error_msg))
+         return_queue.put((proc_idx, error_msg))
+ # -------------------------------------------------------------
+ # 3. MAIN SERVER CLASS
+ # -------------------------------------------------------------
+ 
+ class SeedVRServer:
+     def __init__(self, **kwargs):
+         """Initializes the server, defines the paths and prepares the environment."""
+         print("⚙️ SeedVRServer initializing...")
+         self.SEEDVR_ROOT = SEEDVR_REPO_PATH
+         self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
+         self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
+         self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
+         self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
+         self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
+         self.NUM_GPUS_TOTAL = torch.cuda.device_count()
+ 
+         for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
+             p.mkdir(parents=True, exist_ok=True)
+ 
+         self.setup_dependencies()
+         print("📦 SeedVRServer ready.")
+ 
+     def setup_dependencies(self):
+         """Ensures the repository and the models are present."""
+         # Clone the SeedVR repository if it does not exist
+         if not (self.SEEDVR_ROOT / ".git").exists():
+             print(f"[SeedVRServer] Cloning repository into {self.SEEDVR_ROOT}...")
+             subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
+         else:
+             print("[SeedVRServer] SeedVR repository already exists.")
+ 
+         # Download the Hugging Face checkpoints if they are missing
+         print(f"[SeedVRServer] Checking checkpoints in {self.CKPTS_ROOT}...")
+         model_files = {
+             "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
+             "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses"
+         }
+         for filename, repo_id in model_files.items():
+             if not (self.CKPTS_ROOT / filename).exists():
+                 print(f"Downloading {filename}...")
+                 # FIX: hf_hub_download is already imported at module level
+                 hf_hub_download(
+                     repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
+                     cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
+                 )
+         print("[SeedVRServer] Checkpoints are in place.")
+ 
+     def run_inference(
+         self,
+         file_path: str, *,
+         seed: int,
+         resolution: int,
+         batch_size: int,
+         model: str = "seedvr2_ema_3b_fp16.safetensors",
+         fps: Optional[float] = None,
+         debug: bool = False,
+         preserve_vram: bool = True,
+         progress: Optional[Callable] = None
+     ) -> str:
+         """
+         Runs the full video upscaling pipeline and returns the output file path.
+         """
+         if progress: progress(0.01, "⌛ Initializing...")
+ 
+         # --- 1. Frame extraction ---
+         if progress: progress(0.05, "🎬 Extracting frames from the video...")
+         frames_tensor, original_fps = extract_frames_from_video(file_path, debug)
+ 
+         # --- 2. Multi-GPU processing setup ---
+         device_list = list(range(self.NUM_GPUS_TOTAL))
+         num_devices = len(device_list)
+         # FIX: fail fast instead of dividing by zero below when no GPU is visible
+         if num_devices == 0:
+             raise RuntimeError("No CUDA devices available for SeedVR inference.")
+         chunks = torch.chunk(frames_tensor, num_devices, dim=0)
+ 
+         manager = mp.Manager()
+         return_queue = manager.Queue()
+         progress_queue = manager.Queue() if progress else None
+ 
+         shared_args = {
+             "model": model, "model_dir": str(self.CKPTS_ROOT), "preserve_vram": preserve_vram,
+             "debug": debug, "seed": seed, "resolution": resolution, "batch_size": batch_size
+         }
+ 
+         # --- 3. Start the workers ---
+         if progress: progress(0.1, f"🚀 Starting generation on {num_devices} GPUs...")
+         workers = []
+         for idx, device_id in enumerate(device_list):
+             p = mp.Process(target=_worker_process, args=(idx, device_id, chunks[idx].cpu().numpy(), shared_args, return_queue, progress_queue))
+             p.start()
+             workers.append(p)
+ 
+         # --- 4. Result collection and progress monitoring ---
+         results_np = [None] * num_devices
+         finished_workers = 0
+         worker_progress = [0.0] * num_devices
+         while finished_workers < num_devices:
+             # Update the progress bar with information from the queue
+             if progress_queue:
+                 while not progress_queue.empty():
+                     try:
+                         p_idx, b_idx, b_total, msg = progress_queue.get_nowait()
+                         if b_idx == -1: raise RuntimeError(f"Error in worker {p_idx}: {msg}")
+                         if b_total > 0: worker_progress[p_idx] = b_idx / b_total
+                         total_progress = sum(worker_progress) / num_devices
+                         progress(0.1 + total_progress * 0.85, desc=f"GPU {p_idx+1}/{num_devices}: {msg}")
+                     except queue.Empty: pass
+ 
+             # Check whether any worker has finished
+             try:
+                 proc_idx, result = return_queue.get(timeout=0.2)
+                 if isinstance(result, str): raise RuntimeError(f"Worker {proc_idx} failed: {result}")
+                 results_np[proc_idx] = result
+                 worker_progress[proc_idx] = 1.0
+                 finished_workers += 1
+             except queue.Empty: pass
+ 
+         for p in workers: p.join()
+ 
+         if any(r is None for r in results_np):
+             raise RuntimeError("One or more workers failed to return a result.")
+ 
+         # --- 5. Combine the results and save the final video ---
+         result_tensor = torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
+ 
+         if progress: progress(0.95, "💾 Saving the final video...")
+ 
+         out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
+         out_dir.mkdir(parents=True, exist_ok=True)
+         output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
+ 
+         final_fps = fps if fps and fps > 0 else original_fps
+         save_frames_to_video(result_tensor, str(output_filepath), final_fps, debug)
+ 
+         print(f"✅ Video saved successfully at: {output_filepath}")
+         return str(output_filepath)
+ 
+ # -------------------------------------------------------------
+ # 4. ENTRY POINT
+ # -------------------------------------------------------------
+ 
+ if __name__ == "__main__":
+     # Block for tests or standalone startup.
+     print("🚀 Running the SeedVR server in standalone mode...")
+     try:
+         server = SeedVRServer()
+         print("✅ Server initialized successfully. Ready to receive calls.")
+         # Example of how to call inference (requires a video file):
+         # input_video = "path/to/your/video.mp4"
+         # if os.path.exists(input_video):
+         #     server.run_inference(
+         #         file_path=input_video,
+         #         seed=42,
+         #         resolution=1072,
+         #         batch_size=4,
+         #         progress=lambda p, desc: print(f"Progress: {p*100:.1f}% - {desc}")
+         #     )
+         # else:
+         #     print(f"Test video not found at '{input_video}'. Skipping inference.")
+     except Exception as e:
+         print(f"❌ Failed to initialize the server: {e}")
+         import traceback
+         traceback.print_exc()
+         sys.exit(1)