EuuIia commited on
Commit
9ad5b00
·
verified ·
1 Parent(s): 296cade

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +52 -164
api/seedvr_server.py CHANGED
@@ -3,74 +3,31 @@
3
  import os
4
  import sys
5
  import shutil
 
6
  import time
7
- import subprocess
8
- import multiprocessing as mp
9
- import queue
10
  from pathlib import Path
11
  from typing import Optional, Callable
12
  from types import SimpleNamespace
13
 
14
- # --- Importações diretas de bibliotecas e do repositório SeedVR ---
15
- import torch
16
- import cv2
17
- import numpy as np
18
  from huggingface_hub import hf_hub_download
19
 
20
- # Adiciona o caminho do repositório para importar os módulos do 'src'
 
21
  SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
22
  if str(SEEDVR_REPO_PATH) not in sys.path:
 
23
  sys.path.insert(0, str(SEEDVR_REPO_PATH))
24
 
 
 
25
  try:
26
- from src.core.generation import generation_loop
27
- from src.core.model_manager import configure_runner
28
- from src.utils.downloads import download_weight
29
  except ImportError as e:
30
- print(f"ERRO FATAL: Não foi possível importar os módulos do SeedVR a partir de '{SEEDVR_REPO_PATH}'.")
 
31
  raise e
32
 
33
- # --- Função do Worker (definida fora da classe para ser 'picklable' pelo multiprocessing) ---
34
-
35
- def _worker_entry_point(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
36
- """
37
- Ponto de entrada para cada processo filho (worker).
38
- Esta função executa em um processo separado e em uma GPU designada.
39
- """
40
- os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
41
- os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
42
-
43
- # Imports são feitos aqui para garantir um ambiente limpo para cada processo
44
- import torch
45
-
46
- frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
47
-
48
- local_progress_callback = None
49
- if progress_queue:
50
- def callback_wrapper(batch_idx, total_batches, current_frames, message):
51
- progress_queue.put((proc_idx, batch_idx, total_batches, message))
52
- local_progress_callback = callback_wrapper
53
-
54
- try:
55
- # Cada worker configura seu próprio 'runner' para sua GPU
56
- runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
57
-
58
- result_tensor = generation_loop(
59
- runner=runner, images=frames_tensor, cfg_scale=shared_args["cfg_scale"],
60
- seed=shared_args["seed"], res_w=shared_args["res_w"], batch_size=shared_args["batch_size"],
61
- preserve_vram=shared_args["preserve_vram"], temporal_overlap=shared_args["temporal_overlap"],
62
- debug=shared_args["debug"],
63
- progress_callback=local_progress_callback
64
- )
65
- return_queue.put((proc_idx, result_tensor.cpu().numpy()))
66
- except Exception as e:
67
- import traceback
68
- error_msg = f"ERROR in worker {proc_idx}: {e}\n{traceback.format_exc()}"
69
- print(error_msg)
70
- if progress_queue: progress_queue.put((proc_idx, -1, -1, error_msg))
71
- return_queue.put((proc_idx, error_msg))
72
-
73
-
74
  class SeedVRServer:
75
  def __init__(self, **kwargs):
76
  """
@@ -84,12 +41,12 @@ class SeedVRServer:
84
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
85
  self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
86
 
87
- print("🚀 SeedVRServer (Arquitetura Integrada) inicializando...")
88
  for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
89
  p.mkdir(parents=True, exist_ok=True)
90
 
91
  self.setup_dependencies()
92
- print("✅ SeedVRServer (Arquitetura Integrada) pronto.")
93
 
94
  def setup_dependencies(self):
95
  """ Garante que o repositório e os modelos estão presentes. """
@@ -100,6 +57,7 @@ class SeedVRServer:
100
  """ Clona o repositório do SeedVR se ele não existir. """
101
  if not (self.SEEDVR_ROOT / ".git").exists():
102
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
 
103
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
104
  else:
105
  print("[SeedVRServer] Repositório SeedVR já existe.")
@@ -121,135 +79,65 @@ class SeedVRServer:
121
  cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
122
  )
123
  print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
124
-
125
- def run_inference(
126
  self,
127
  file_path: str, *,
128
  seed: int, res_h: int, res_w: int, sp_size: int,
129
  fps: Optional[float] = None, progress: Optional[Callable] = None
130
  ) -> str:
131
  """
132
- Método público principal para executar a inferência.
133
  """
 
134
  out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
135
  out_dir.mkdir(parents=True, exist_ok=True)
136
  output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  try:
139
- if progress: progress(0.05, "Extracting frames...")
140
- frames_tensor, original_fps = self._extract_frames(file_path, debug=True)
141
-
142
- if progress: progress(0.1, "Starting parallel processing...")
143
 
144
- # Chama o método de processamento paralelo
145
- result_tensor = self._process_in_parallel(
146
- frames_tensor=frames_tensor,
147
- seed=seed, resolution=res_h, batch_size=sp_size,
148
- progress_callback=progress
149
- )
150
 
151
- if progress: progress(0.95, "Saving the final video...")
 
 
 
 
152
  final_fps = fps if fps and fps > 0 else original_fps
153
- self._save_video(result_tensor, str(output_filepath), final_fps, debug=True)
154
 
155
  print(f"✅ Video saved successfully to: {output_filepath}")
 
 
156
  return str(output_filepath)
157
 
158
  except Exception as e:
159
- print(f"❌ Error during inference execution: {e}")
160
  import traceback
161
  traceback.print_exc()
 
162
  raise
163
-
164
- def _extract_frames(self, video_path: str, debug: bool = False):
165
- """Método privado para extrair quadros de um vídeo."""
166
- # (Este é o código da função extract_frames_from_video, agora como um método de classe)
167
- if debug: print(f"🎬 Extracting frames from video: {video_path}")
168
- if not os.path.exists(video_path): raise FileNotFoundError(f"Video file not found: {video_path}")
169
- cap = cv2.VideoCapture(video_path)
170
- if not cap.isOpened(): raise ValueError(f"Cannot open video file: {video_path}")
171
- fps = cap.get(cv2.CAP_PROP_FPS)
172
- frames = []
173
- while True:
174
- ret, frame = cap.read()
175
- if not ret: break
176
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
177
- frame = frame.astype(np.float32) / 255.0
178
- frames.append(frame)
179
- cap.release()
180
- if not frames: raise ValueError(f"No frames extracted from video: {video_path}")
181
- return torch.from_numpy(np.stack(frames)).to(torch.float16), fps
182
-
183
- def _save_video(self, frames_tensor: torch.Tensor, output_path: str, fps: float, debug: bool = False):
184
- """Método privado para salvar um tensor de quadros em um arquivo de vídeo."""
185
- # (Este é o código da função save_frames_to_video, agora como um método de classe)
186
- if debug: print(f"🎬 Saving {frames_tensor.shape[0]} frames to video: {output_path}")
187
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
188
- frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
189
- _, H, W, _ = frames_np.shape
190
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
191
- out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
192
- for frame in frames_np:
193
- out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
194
- out.release()
195
- if debug: print(f"✅ Video saved successfully: {output_path}")
196
-
197
- def _process_in_parallel(self, frames_tensor: torch.Tensor, seed: int, resolution: int, batch_size: int, progress_callback: Optional[Callable] = None):
198
- """
199
- Método privado que gerencia a lógica de multiprocessamento.
200
- """
201
- device_list = list(range(self.NUM_GPUS_TOTAL))
202
- num_devices = len(device_list)
203
- chunks = torch.chunk(frames_tensor, num_devices, dim=0)
204
-
205
- manager = mp.Manager()
206
- return_queue = manager.Queue()
207
- progress_queue = manager.Queue() if progress_callback else None
208
- workers = []
209
-
210
- shared_args = {
211
- "model": "seedvr2_ema_3b_fp16.safetensors", "model_dir": str(self.CKPTS_ROOT),
212
- "preserve_vram": True, "debug": True, "cfg_scale": 1.0,
213
- "seed": seed, "res_w": resolution, "batch_size": batch_size, "temporal_overlap": 0,
214
- }
215
-
216
- for idx, device_id in enumerate(device_list):
217
- p = mp.Process(target=_worker_entry_point, args=(idx, device_id, chunks[idx].cpu().numpy(), shared_args, return_queue, progress_queue))
218
- p.start()
219
- workers.append(p)
220
-
221
- results_np = [None] * num_devices
222
- finished_workers_count = 0
223
- worker_progress = [0.0] * num_devices
224
-
225
- while finished_workers_count < num_devices:
226
- if progress_queue:
227
- while not progress_queue.empty():
228
- try:
229
- proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
230
- if batch_idx == -1: raise RuntimeError(f"Worker {proc_idx} error: {message}")
231
- if total_batches > 0: worker_progress[proc_idx] = batch_idx / total_batches
232
- total_progress = sum(worker_progress) / num_devices
233
- progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
234
- except queue.Empty:
235
- break
236
-
237
- try:
238
- proc_idx, result = return_queue.get(timeout=0.2)
239
- if isinstance(result, str) and result.startswith("ERROR"):
240
- raise RuntimeError(f"Worker {proc_idx} failed: {result}")
241
- results_np[proc_idx] = result
242
- worker_progress[proc_idx] = 1.0
243
- finished_workers_count += 1
244
- if progress_callback:
245
- total_progress = sum(worker_progress) / num_devices
246
- progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Completed!")
247
- except queue.Empty:
248
- pass
249
-
250
- for p in workers: p.join()
251
-
252
- if any(r is None for r in results_np):
253
- raise RuntimeError("One or more workers failed to return a result.")
254
-
255
- return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
 
3
  import os
4
  import sys
5
  import shutil
6
+ import mimetypes
7
  import time
8
+ import subprocess # Necessário para clonar o repositório na configuração inicial
 
 
9
  from pathlib import Path
10
  from typing import Optional, Callable
11
  from types import SimpleNamespace
12
 
 
 
 
 
13
  from huggingface_hub import hf_hub_download
14
 
15
+ # Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
16
+ # Isso é crucial para que a importação do 'inference_cli' funcione.
17
  SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
18
  if str(SEEDVR_REPO_PATH) not in sys.path:
19
+ # Insere no início da lista para garantir prioridade de importação.
20
  sys.path.insert(0, str(SEEDVR_REPO_PATH))
21
 
22
+ # Tenta importar as funções necessárias APÓS a modificação do path.
23
+ # Se falhar, a aplicação não pode continuar.
24
  try:
25
+ from inference_cli import run_inference_logic, save_frames_to_video
 
 
26
  except ImportError as e:
27
+ print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
28
+ print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
29
  raise e
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class SeedVRServer:
32
  def __init__(self, **kwargs):
33
  """
 
41
  self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
42
  self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
43
 
44
+ print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
45
  for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
46
  p.mkdir(parents=True, exist_ok=True)
47
 
48
  self.setup_dependencies()
49
+ print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")
50
 
51
  def setup_dependencies(self):
52
  """ Garante que o repositório e os modelos estão presentes. """
 
57
  """ Clona o repositório do SeedVR se ele não existir. """
58
  if not (self.SEEDVR_ROOT / ".git").exists():
59
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
60
+ # Usamos subprocess.run aqui porque é uma tarefa de inicialização única.
61
  subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
62
  else:
63
  print("[SeedVRServer] Repositório SeedVR já existe.")
 
79
  cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
80
  )
81
  print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
82
+
83
+ def run_inference_direct(
84
  self,
85
  file_path: str, *,
86
  seed: int, res_h: int, res_w: int, sp_size: int,
87
  fps: Optional[float] = None, progress: Optional[Callable] = None
88
  ) -> str:
89
  """
90
+ Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
91
  """
92
+ # Cria um diretório de saída único para salvar o resultado.
93
  out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
94
  out_dir.mkdir(parents=True, exist_ok=True)
95
  output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
96
 
97
+ # Simula o objeto 'args' que a função de lógica do inference_cli espera.
98
+ # Usamos SimpleNamespace para criar um objeto simples com atributos.
99
+ args = SimpleNamespace(
100
+ video_path=file_path,
101
+ output=str(output_filepath),
102
+ model_dir=str(self.CKPTS_ROOT),
103
+ seed=seed,
104
+ resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência.
105
+ batch_size=sp_size,
106
+ model="seedvr2_ema_3b_fp16.safetensors",
107
+ preserve_vram=True,
108
+ debug=True, # Mantém o debug ativo para logs detalhados.
109
+ cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
110
+ skip_first_frames=0,
111
+ load_cap=0,
112
+ output_format='video' # Garante que sempre gere vídeo
113
+ )
114
+
115
  try:
116
+ # Informa a UI que o processo começou.
117
+ if progress:
118
+ progress(0.01, "Initializing...")
 
119
 
120
+ # Chama a função importada do script original, passando o callback de progresso.
121
+ # Este callback será chamado de dentro da lógica de multi-processamento.
122
+ result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
 
 
 
123
 
124
+ # Informa a UI que a inferência terminou e o salvamento vai começar.
125
+ if progress:
126
+ progress(0.95, "Saving the final video...")
127
+
128
+ # Define o FPS final: usa o valor da UI ou o original do vídeo de entrada.
129
  final_fps = fps if fps and fps > 0 else original_fps
130
+ save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
131
 
132
  print(f"✅ Video saved successfully to: {output_filepath}")
133
+
134
+ # Retorna o caminho do arquivo gerado para a UI.
135
  return str(output_filepath)
136
 
137
  except Exception as e:
138
+ print(f"❌ Error during direct inference execution: {e}")
139
  import traceback
140
  traceback.print_exc()
141
+ # Propaga o erro para a UI do Gradio, que o exibirá de forma amigável.
142
  raise
143
+