EuuIia commited on
Commit
2991b2f
·
verified ·
1 Parent(s): c2b6094

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +164 -52
api/seedvr_server.py CHANGED
@@ -3,31 +3,74 @@
3
  import os
4
  import sys
5
  import shutil
6
- import mimetypes
7
  import time
8
- import subprocess # Necessário para clonar o repositório na configuração inicial
 
 
9
  from pathlib import Path
10
  from typing import Optional, Callable
11
  from types import SimpleNamespace
12
 
 
 
 
 
13
  from huggingface_hub import hf_hub_download
14
 
15
- # Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
16
- # Isso é crucial para que a importação do 'inference_cli' funcione.
17
  SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
18
  if str(SEEDVR_REPO_PATH) not in sys.path:
19
- # Insere no início da lista para garantir prioridade de importação.
20
  sys.path.insert(0, str(SEEDVR_REPO_PATH))
21
 
22
- # Tenta importar as funções necessárias APÓS a modificação do path.
23
- # Se falhar, a aplicação não pode continuar.
24
  try:
25
- from inference_cli import run_inference_logic, save_frames_to_video
 
 
26
  except ImportError as e:
27
- print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
28
- print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
29
  raise e
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class SeedVRServer:
32
  def __init__(self, **kwargs):
33
  """
@@ -41,12 +84,12 @@ class SeedVRServer:
41
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
42
  self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
43
 
44
- print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
45
  for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
46
  p.mkdir(parents=True, exist_ok=True)
47
 
48
  self.setup_dependencies()
49
- print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")
50
 
51
  def setup_dependencies(self):
52
  """ Garante que o repositório e os modelos estão presentes. """
@@ -57,7 +100,6 @@ class SeedVRServer:
57
  """ Clona o repositório do SeedVR se ele não existir. """
58
  if not (self.SEEDVR_ROOT / ".git").exists():
59
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
60
- # Usamos subprocess.run aqui porque é uma tarefa de inicialização única.
61
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
62
  else:
63
  print("[SeedVRServer] Repositório SeedVR já existe.")
@@ -79,65 +121,135 @@ class SeedVRServer:
79
  cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
80
  )
81
  print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
82
-
83
- def run_inference_direct(
84
  self,
85
  file_path: str, *,
86
  seed: int, res_h: int, res_w: int, sp_size: int,
87
  fps: Optional[float] = None, progress: Optional[Callable] = None
88
  ) -> str:
89
  """
90
- Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
91
  """
92
- # Cria um diretório de saída único para salvar o resultado.
93
  out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
94
  out_dir.mkdir(parents=True, exist_ok=True)
95
  output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
96
 
97
- # Simula o objeto 'args' que a função de lógica do inference_cli espera.
98
- # Usamos SimpleNamespace para criar um objeto simples com atributos.
99
- args = SimpleNamespace(
100
- video_path=file_path,
101
- output=str(output_filepath),
102
- model_dir=str(self.CKPTS_ROOT),
103
- seed=seed,
104
- resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência.
105
- batch_size=sp_size,
106
- model="seedvr2_ema_3b_fp16.safetensors",
107
- preserve_vram=True,
108
- debug=True, # Mantém o debug ativo para logs detalhados.
109
- cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
110
- skip_first_frames=0,
111
- load_cap=0,
112
- output_format='video' # Garante que sempre gere vídeo
113
- )
114
-
115
  try:
116
- # Informa a UI que o processo começou.
117
- if progress:
118
- progress(0.01, "Initializing...")
119
 
120
- # Chama a função importada do script original, passando o callback de progresso.
121
- # Este callback será chamado de dentro da lógica de multi-processamento.
122
- result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
123
 
124
- # Informa a UI que a inferência terminou e o salvamento vai começar.
125
- if progress:
126
- progress(0.95, "Saving the final video...")
127
-
128
- # Define o FPS final: usa o valor da UI ou o original do vídeo de entrada.
 
 
 
129
  final_fps = fps if fps and fps > 0 else original_fps
130
- save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
131
 
132
  print(f"✅ Video saved successfully to: {output_filepath}")
133
-
134
- # Retorna o caminho do arquivo gerado para a UI.
135
  return str(output_filepath)
136
 
137
  except Exception as e:
138
- print(f"❌ Error during direct inference execution: {e}")
139
  import traceback
140
  traceback.print_exc()
141
- # Propaga o erro para a UI do Gradio, que o exibirá de forma amigável.
142
  raise
143
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import sys
5
  import shutil
 
6
  import time
7
+ import subprocess
8
+ import multiprocessing as mp
9
+ import queue
10
  from pathlib import Path
11
  from typing import Optional, Callable
12
  from types import SimpleNamespace
13
 
14
+ # --- Importações diretas de bibliotecas e do repositório SeedVR ---
15
+ import torch
16
+ import cv2
17
+ import numpy as np
18
  from huggingface_hub import hf_hub_download
19
 
20
+ # Adiciona o caminho do repositório para importar os módulos do 'src'
 
21
  SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
22
  if str(SEEDVR_REPO_PATH) not in sys.path:
 
23
  sys.path.insert(0, str(SEEDVR_REPO_PATH))
24
 
 
 
25
  try:
26
+ from src.core.generation import generation_loop
27
+ from src.core.model_manager import configure_runner
28
+ from src.utils.downloads import download_weight
29
  except ImportError as e:
30
+ print(f"ERRO FATAL: Não foi possível importar os módulos do SeedVR a partir de '{SEEDVR_REPO_PATH}'.")
 
31
  raise e
32
 
33
+ # --- Função do Worker (definida fora da classe para ser 'picklable' pelo multiprocessing) ---
34
+
35
+ def _worker_entry_point(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
36
+ """
37
+ Ponto de entrada para cada processo filho (worker).
38
+ Esta função executa em um processo separado e em uma GPU designada.
39
+ """
40
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
41
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
42
+
43
+ # Imports são feitos aqui para garantir um ambiente limpo para cada processo
44
+ import torch
45
+
46
+ frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
47
+
48
+ local_progress_callback = None
49
+ if progress_queue:
50
+ def callback_wrapper(batch_idx, total_batches, current_frames, message):
51
+ progress_queue.put((proc_idx, batch_idx, total_batches, message))
52
+ local_progress_callback = callback_wrapper
53
+
54
+ try:
55
+ # Cada worker configura seu próprio 'runner' para sua GPU
56
+ runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
57
+
58
+ result_tensor = generation_loop(
59
+ runner=runner, images=frames_tensor, cfg_scale=shared_args["cfg_scale"],
60
+ seed=shared_args["seed"], res_w=shared_args["res_w"], batch_size=shared_args["batch_size"],
61
+ preserve_vram=shared_args["preserve_vram"], temporal_overlap=shared_args["temporal_overlap"],
62
+ debug=shared_args["debug"],
63
+ progress_callback=local_progress_callback
64
+ )
65
+ return_queue.put((proc_idx, result_tensor.cpu().numpy()))
66
+ except Exception as e:
67
+ import traceback
68
+ error_msg = f"ERROR in worker {proc_idx}: {e}\n{traceback.format_exc()}"
69
+ print(error_msg)
70
+ if progress_queue: progress_queue.put((proc_idx, -1, -1, error_msg))
71
+ return_queue.put((proc_idx, error_msg))
72
+
73
+
74
  class SeedVRServer:
75
  def __init__(self, **kwargs):
76
  """
 
84
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
85
  self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
86
 
87
+ print("🚀 SeedVRServer (Arquitetura Integrada) inicializando...")
88
  for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
89
  p.mkdir(parents=True, exist_ok=True)
90
 
91
  self.setup_dependencies()
92
+ print("✅ SeedVRServer (Arquitetura Integrada) pronto.")
93
 
94
  def setup_dependencies(self):
95
  """ Garante que o repositório e os modelos estão presentes. """
 
100
  """ Clona o repositório do SeedVR se ele não existir. """
101
  if not (self.SEEDVR_ROOT / ".git").exists():
102
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
 
103
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
104
  else:
105
  print("[SeedVRServer] Repositório SeedVR já existe.")
 
121
  cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
122
  )
123
  print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
124
+
125
+ def run_inference(
126
  self,
127
  file_path: str, *,
128
  seed: int, res_h: int, res_w: int, sp_size: int,
129
  fps: Optional[float] = None, progress: Optional[Callable] = None
130
  ) -> str:
131
  """
132
+ Método público principal para executar a inferência.
133
  """
 
134
  out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
135
  out_dir.mkdir(parents=True, exist_ok=True)
136
  output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  try:
139
+ if progress: progress(0.05, "Extracting frames...")
140
+ frames_tensor, original_fps = self._extract_frames(file_path, debug=True)
 
141
 
142
+ if progress: progress(0.1, "Starting parallel processing...")
 
 
143
 
144
+ # Chama o método de processamento paralelo
145
+ result_tensor = self._process_in_parallel(
146
+ frames_tensor=frames_tensor,
147
+ seed=seed, resolution=res_h, batch_size=sp_size,
148
+ progress_callback=progress
149
+ )
150
+
151
+ if progress: progress(0.95, "Saving the final video...")
152
  final_fps = fps if fps and fps > 0 else original_fps
153
+ self._save_video(result_tensor, str(output_filepath), final_fps, debug=True)
154
 
155
  print(f"✅ Video saved successfully to: {output_filepath}")
 
 
156
  return str(output_filepath)
157
 
158
  except Exception as e:
159
+ print(f"❌ Error during inference execution: {e}")
160
  import traceback
161
  traceback.print_exc()
 
162
  raise
163
+
164
+ def _extract_frames(self, video_path: str, debug: bool = False):
165
+ """Método privado para extrair quadros de um vídeo."""
166
+ # (Este é o código da função extract_frames_from_video, agora como um método de classe)
167
+ if debug: print(f"🎬 Extracting frames from video: {video_path}")
168
+ if not os.path.exists(video_path): raise FileNotFoundError(f"Video file not found: {video_path}")
169
+ cap = cv2.VideoCapture(video_path)
170
+ if not cap.isOpened(): raise ValueError(f"Cannot open video file: {video_path}")
171
+ fps = cap.get(cv2.CAP_PROP_FPS)
172
+ frames = []
173
+ while True:
174
+ ret, frame = cap.read()
175
+ if not ret: break
176
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
177
+ frame = frame.astype(np.float32) / 255.0
178
+ frames.append(frame)
179
+ cap.release()
180
+ if not frames: raise ValueError(f"No frames extracted from video: {video_path}")
181
+ return torch.from_numpy(np.stack(frames)).to(torch.float16), fps
182
+
183
+ def _save_video(self, frames_tensor: torch.Tensor, output_path: str, fps: float, debug: bool = False):
184
+ """Método privado para salvar um tensor de quadros em um arquivo de vídeo."""
185
+ # (Este é o código da função save_frames_to_video, agora como um método de classe)
186
+ if debug: print(f"🎬 Saving {frames_tensor.shape[0]} frames to video: {output_path}")
187
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
188
+ frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
189
+ _, H, W, _ = frames_np.shape
190
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
191
+ out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
192
+ for frame in frames_np:
193
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
194
+ out.release()
195
+ if debug: print(f"✅ Video saved successfully: {output_path}")
196
+
197
+ def _process_in_parallel(self, frames_tensor: torch.Tensor, seed: int, resolution: int, batch_size: int, progress_callback: Optional[Callable] = None):
198
+ """
199
+ Método privado que gerencia a lógica de multiprocessamento.
200
+ """
201
+ device_list = list(range(self.NUM_GPUS_TOTAL))
202
+ num_devices = len(device_list)
203
+ chunks = torch.chunk(frames_tensor, num_devices, dim=0)
204
+
205
+ manager = mp.Manager()
206
+ return_queue = manager.Queue()
207
+ progress_queue = manager.Queue() if progress_callback else None
208
+ workers = []
209
+
210
+ shared_args = {
211
+ "model": "seedvr2_ema_3b_fp16.safetensors", "model_dir": str(self.CKPTS_ROOT),
212
+ "preserve_vram": True, "debug": True, "cfg_scale": 1.0,
213
+ "seed": seed, "res_w": resolution, "batch_size": batch_size, "temporal_overlap": 0,
214
+ }
215
+
216
+ for idx, device_id in enumerate(device_list):
217
+ p = mp.Process(target=_worker_entry_point, args=(idx, device_id, chunks[idx].cpu().numpy(), shared_args, return_queue, progress_queue))
218
+ p.start()
219
+ workers.append(p)
220
+
221
+ results_np = [None] * num_devices
222
+ finished_workers_count = 0
223
+ worker_progress = [0.0] * num_devices
224
+
225
+ while finished_workers_count < num_devices:
226
+ if progress_queue:
227
+ while not progress_queue.empty():
228
+ try:
229
+ proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
230
+ if batch_idx == -1: raise RuntimeError(f"Worker {proc_idx} error: {message}")
231
+ if total_batches > 0: worker_progress[proc_idx] = batch_idx / total_batches
232
+ total_progress = sum(worker_progress) / num_devices
233
+ progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
234
+ except queue.Empty:
235
+ break
236
+
237
+ try:
238
+ proc_idx, result = return_queue.get(timeout=0.2)
239
+ if isinstance(result, str) and result.startswith("ERROR"):
240
+ raise RuntimeError(f"Worker {proc_idx} failed: {result}")
241
+ results_np[proc_idx] = result
242
+ worker_progress[proc_idx] = 1.0
243
+ finished_workers_count += 1
244
+ if progress_callback:
245
+ total_progress = sum(worker_progress) / num_devices
246
+ progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Completed!")
247
+ except queue.Empty:
248
+ pass
249
+
250
+ for p in workers: p.join()
251
+
252
+ if any(r is None for r in results_np):
253
+ raise RuntimeError("One or more workers failed to return a result.")
254
+
255
+ return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)