EuuIia committed on
Commit
5aa9c14
·
verified ·
1 Parent(s): 8a8a15b

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +345 -15
api/seedvr_server.py CHANGED
@@ -21,12 +21,342 @@ if str(SEEDVR_REPO_PATH) not in sys.path:
21
 
22
  # Tenta importar as funções necessárias APÓS a modificação do path.
23
  # Se falhar, a aplicação não pode continuar.
24
- try:
25
- from inference_cli import run_inference_logic, save_frames_to_video
26
- except ImportError as e:
27
- print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
28
- print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
29
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
 
32
 
@@ -46,16 +376,16 @@ class SeedVRServer:
46
  self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
47
 
48
 
49
- if INIT:
50
- print("🚀 SeedVRServer ja inicializando...")
51
- else:
52
- print("⚙️ SeedVRServer (Modo de Chamada Direta) inicializando...")
53
- for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
54
- p.mkdir(parents=True, exist_ok=True)
55
 
56
- self.setup_dependencies()
57
- print("📦 SeedVRServer (Modo de Chamada Direta) pronto.")
58
- INIT = True
59
 
60
 
61
  def setup_dependencies(self):
 
21
 
22
  # Tenta importar as funções necessárias APÓS a modificação do path.
23
  # Se falhar, a aplicação não pode continuar.
24
#!/usr/bin/env python3
"""
Standalone SeedVR2 Video Upscaler CLI script.

(Modified for robust progress-callback support under multiprocessing.)
"""

import sys
import os
import argparse
import time
import multiprocessing as mp
import queue  # queue.Empty is needed to detect drained IPC queues

# CUDA with multiprocessing requires the 'spawn' start method for stability.
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)

# -------------------------------------------------------------
# 1) VRAM allocator configuration (must be set before torch is imported).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")

# 2) Pre-parse the CLI arguments so CUDA device visibility can be configured
#    before any CUDA-aware library is imported.
_pre_parser = argparse.ArgumentParser(add_help=False)
_pre_parser.add_argument("--cuda_device", type=str, default=None)
_pre_args, _ = _pre_parser.parse_known_args()
if _pre_args.cuda_device is not None:
    device_list_env = [x.strip() for x in _pre_args.cuda_device.split(',') if x.strip() != '']
    # Only restrict visibility for the single-GPU case; in multi-GPU runs each
    # worker process sets its own CUDA_VISIBLE_DEVICES.
    if len(device_list_env) == 1:
        os.environ["CUDA_VISIBLE_DEVICES"] = device_list_env[0]

# -------------------------------------------------------------
# 3) Heavy imports (torch, etc.) happen only after the environment is set up.
import torch
import cv2
import numpy as np
from datetime import datetime
from pathlib import Path

# Add the project root to sys.path so `src` imports resolve.
# This assumes the script lives inside the cloned repository.
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.insert(0, script_dir)

# Import the SeedVR helpers AFTER adjusting the path.
from src.utils.downloads import download_weight
def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
    """
    Extract frames from a video and convert them to a float16 tensor.

    Args:
        video_path: Path to the input video file.
        debug: If True, print progress information.
        skip_first_frames: Number of leading frames to discard.
        load_cap: Maximum number of frames to load (None or 0 means load all).

    Returns:
        (frames_tensor, fps): a (T, H, W, 3) float16 tensor with RGB values
        in [0, 1], and the source video's frame rate.

    Raises:
        FileNotFoundError: If the video file does not exist.
        ValueError: If the video cannot be opened or yields no frames.
    """
    if debug:
        print(f"🎬 Extracting frames from video: {video_path}")

    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")

    # BUGFIX: release the capture handle even if decoding raises mid-loop
    # (the original leaked the handle on any exception).
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if debug:
            print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
            if skip_first_frames:
                print(f"⏭️ Will skip first {skip_first_frames} frames")
            if load_cap and load_cap > 0:
                print(f"🔢 Will load maximum {load_cap} frames")

        frames = []
        frame_idx = 0
        frames_loaded = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_idx < skip_first_frames:
                frame_idx += 1
                continue

            if load_cap is not None and load_cap > 0 and frames_loaded >= load_cap:
                break

            # OpenCV decodes BGR; convert to RGB and normalize to [0, 1].
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = frame.astype(np.float32) / 255.0

            frames.append(frame)
            frame_idx += 1
            frames_loaded += 1

            if debug and frames_loaded % 100 == 0:
                total_to_load = min(frame_count, load_cap) if (load_cap and load_cap > 0) else frame_count
                print(f"📹 Extracted {frames_loaded}/{total_to_load} frames")
    finally:
        cap.release()

    if len(frames) == 0:
        raise ValueError(f"No frames extracted from video: {video_path}")

    if debug:
        print(f"✅ Extracted {len(frames)} frames")

    frames_tensor = torch.from_numpy(np.stack(frames)).to(torch.float16)

    if debug:
        print(f"📊 Frames tensor shape: {frames_tensor.shape}, dtype: {frames_tensor.dtype}")

    return frames_tensor, fps
def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
    """
    Save a tensor of frames to a video file.

    Args:
        frames_tensor: (T, H, W, 3) float tensor, RGB, values in [0, 1].
        output_path: Destination .mp4 path; parent directories are created.
        fps: Output frame rate.
        debug: If True, print progress information.

    Raises:
        ValueError: If the video writer cannot be created.
    """
    if debug:
        print(f"🎬 Saving {frames_tensor.shape[0]} frames to video: {output_path}")

    # BUGFIX: dirname is '' for bare filenames (main()'s default output), and
    # os.makedirs('') raises FileNotFoundError — only create when non-empty.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Denormalize back to 8-bit for the encoder.
    frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)

    T, H, W, C = frames_np.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))

    if not out.isOpened():
        raise ValueError(f"Cannot create video writer for: {output_path}")

    # BUGFIX: release the writer even if encoding raises mid-loop.
    try:
        for i, frame in enumerate(frames_np):
            # VideoWriter expects BGR.
            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            out.write(frame_bgr)

            if debug and (i + 1) % 100 == 0:
                print(f"💾 Saved {i + 1}/{T} frames")
    finally:
        out.release()

    if debug:
        print(f"✅ Video saved successfully: {output_path}")
def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
    """
    Child worker process that runs the upscaling on one dedicated GPU.

    Args:
        proc_idx: Index of this worker; used to reassemble results in order.
        device_id: CUDA device id this worker is pinned to.
        frames_np: Numpy chunk of frames assigned to this worker.
        shared_args: Dict of generation parameters shared by all workers.
        return_queue: Receives (proc_idx, result ndarray) on success or
            (proc_idx, error string starting with "ERROR") on failure.
        progress_queue: Optional queue receiving
            (proc_idx, batch_idx, total_batches, message) tuples;
            batch_idx == -1 is the error sentinel read by the parent.
    """
    # Pin this child to its GPU BEFORE importing torch so the restriction
    # takes effect; the allocator config mirrors the parent process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")

    # Heavy imports happen inside the child ('spawn' start method: the child
    # does not inherit the parent's already-imported modules).
    import torch
    from src.core.model_manager import configure_runner
    from src.core.generation import generation_loop

    frames_tensor = torch.from_numpy(frames_np).to(torch.float16)

    local_progress_callback = None
    if progress_queue:
        def callback_wrapper(batch_idx, total_batches, current_frames, message):
            # Forward a progress tuple to the parent via the queue.
            progress_queue.put((proc_idx, batch_idx, total_batches, message))
        local_progress_callback = callback_wrapper

    try:
        runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
        result_tensor = generation_loop(
            runner=runner, images=frames_tensor, cfg_scale=shared_args["cfg_scale"],
            seed=shared_args["seed"], res_w=shared_args["res_w"], batch_size=shared_args["batch_size"],
            preserve_vram=shared_args["preserve_vram"], temporal_overlap=shared_args["temporal_overlap"],
            debug=shared_args["debug"],
            progress_callback=local_progress_callback
        )
        # Move the result to CPU memory so it can cross the process boundary.
        return_queue.put((proc_idx, result_tensor.cpu().numpy()))
    except Exception as e:
        import traceback
        error_msg = f"ERROR in worker {proc_idx}: {e}\n{traceback.format_exc()}"
        print(error_msg)
        if progress_queue:
            # batch_idx == -1 signals the error to the monitoring loop.
            progress_queue.put((proc_idx, -1, -1, error_msg))
        # Also push the error on the result queue so the parent's result
        # polling unblocks ("ERROR" prefix is checked by the parent).
        return_queue.put((proc_idx, error_msg))
def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):
    """
    Split the frames across GPUs, spawn one worker per device, and monitor
    progress and results until every worker has finished.

    Args:
        frames_tensor: Full (T, H, W, 3) frames tensor to upscale.
        device_list: List of CUDA device id strings; one worker per entry.
        args: Parsed CLI namespace (model, seed, resolution, batch_size, ...).
        progress_callback: Optional callable(progress_fraction, desc=...)
            invoked with aggregate progress across all workers.

    Returns:
        Concatenated float16 result tensor, in original frame order.

    Raises:
        RuntimeError: If a worker reports an error or returns no result.
    """
    num_devices = len(device_list)
    # torch.chunk preserves order, so chunk i maps back to slot i on concat.
    chunks = torch.chunk(frames_tensor, num_devices, dim=0)

    manager = mp.Manager()
    return_queue = manager.Queue()
    progress_queue = manager.Queue() if progress_callback else None
    workers = []

    # Plain-picklable parameter dict shared by every worker.
    # NOTE(review): the "./models/SEEDVR2" fallback here is the directory
    # workers load from — keep it in sync with where weights are downloaded.
    shared_args = {
        "model": args.model, "model_dir": args.model_dir or "./models/SEEDVR2",
        "preserve_vram": args.preserve_vram, "debug": args.debug, "cfg_scale": 1.0,
        "seed": args.seed, "res_w": args.resolution, "batch_size": args.batch_size, "temporal_overlap": 0,
    }

    for idx, (device_id, chunk_tensor) in enumerate(zip(device_list, chunks)):
        # Chunks are passed as numpy arrays so they pickle cleanly to the child.
        p = mp.Process(target=_worker_process, args=(idx, device_id, chunk_tensor.cpu().numpy(), shared_args, return_queue, progress_queue))
        p.start()
        workers.append(p)

    results_np = [None] * num_devices
    finished_workers_count = 0
    worker_progress = [0.0] * num_devices  # per-worker completion fraction in [0, 1]

    while finished_workers_count < num_devices:
        # Drain any pending progress messages without blocking.
        if progress_queue:
            while not progress_queue.empty():
                try:
                    proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
                    if batch_idx == -1:
                        # Error sentinel emitted by a failing worker.
                        raise RuntimeError(f"Worker {proc_idx} error: {message}")
                    if total_batches > 0:
                        worker_progress[proc_idx] = batch_idx / total_batches
                        total_progress = sum(worker_progress) / num_devices
                        progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
                except queue.Empty:
                    # empty() is advisory across processes; get_nowait() may
                    # still find the queue drained — just stop draining.
                    break

        # Poll for finished workers; the short timeout keeps the progress
        # loop responsive instead of blocking until a worker completes.
        try:
            proc_idx, result = return_queue.get(timeout=0.2)
            if isinstance(result, str) and result.startswith("ERROR"):
                raise RuntimeError(f"Worker {proc_idx} failed: {result}")
            results_np[proc_idx] = result
            worker_progress[proc_idx] = 1.0
            finished_workers_count += 1
            if progress_callback:
                total_progress = sum(worker_progress) / num_devices
                progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Completed!")
        except queue.Empty:
            pass

    for p in workers: p.join()

    if any(r is None for r in results_np):
        raise RuntimeError("One or more workers failed to return a result.")

    return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
def parse_arguments():
    """Build the CLI argument parser and parse sys.argv."""
    parser = argparse.ArgumentParser(description="SeedVR2 Video Upscaler CLI")
    add = parser.add_argument

    # Input and generation parameters.
    add("--video_path", type=str, required=True, help="Path to input video file")
    add("--seed", type=int, default=100, help="Random seed for generation (default: 100)")
    add("--resolution", type=int, default=1072, help="Target resolution of the short side (default: 1072)")
    add("--batch_size", type=int, default=5, help="Number of frames per batch (default: 5)")

    # Model selection.
    model_choices = [
        "seedvr2_ema_3b_fp16.safetensors",
        "seedvr2_ema_3b_fp8_e4m3fn.safetensors",
        "seedvr2_ema_7b_fp16.safetensors",
        "seedvr2_ema_7b_fp8_e4m3fn.safetensors",
    ]
    add("--model", type=str, default="seedvr2_ema_3b_fp16.safetensors",
        choices=model_choices,
        help="Model to use")
    add("--model_dir", type=str, default=None, help="Directory containing the model files")

    # Frame windowing and output.
    add("--skip_first_frames", type=int, default=0, help="Skip the first frames during processing")
    add("--load_cap", type=int, default=0, help="Maximum number of frames to load from video (default: load all)")
    add("--output", type=str, default=None, help="Output path")
    add("--output_format", type=str, default="video", choices=["video", "png"], help="Output format: 'video' (mp4) or 'png' images")

    # Runtime behavior.
    add("--preserve_vram", action="store_true", help="Enable VRAM preservation mode")
    add("--debug", action="store_true", help="Enable debug logging")
    add("--cuda_device", type=str, default=None, help="CUDA device id(s). e.g., '0' or '0,1' for multi-GPU")

    return parser.parse_args()
def run_inference_logic(args, progress_callback=None):
    """
    Run the full upscaling pipeline: extract frames, fetch model weights,
    process across the configured GPUs, and return the upscaled frames.

    Designed to be importable and callable from other scripts (e.g. a server).

    Args:
        args: Parsed CLI namespace (see parse_arguments()).
        progress_callback: Optional callable(progress_fraction, desc) for UI updates.

    Returns:
        (result_tensor, original_fps, generation_time, num_input_frames)
    """
    if args.debug:
        print(f"📋 Argumentos da Lógica de Inferência: {vars(args)}")

    if progress_callback: progress_callback(0.05, "Extracting frames...")
    print("🎬 Extraindo frames do vídeo...")
    start_time = time.time()
    frames_tensor, original_fps = extract_frames_from_video(
        args.video_path, args.debug, args.skip_first_frames, args.load_cap
    )
    if args.debug:
        print(f"🔄 Tempo de extração de frames: {time.time() - start_time:.2f}s")

    device_list = [d.strip() for d in str(args.cuda_device).split(',') if d.strip()] if args.cuda_device else ["0"]
    if args.debug:
        print(f"🚀 Usando dispositivos: {device_list}")

    if progress_callback: progress_callback(0.1, "Starting generation...")
    processing_start = time.time()
    # BUGFIX: download into the same directory the workers load from.
    # _gpu_processing falls back to "./models/SEEDVR2"; the previous fallback
    # here ("seedvr_models") downloaded weights to a directory never read.
    download_weight(args.model, args.model_dir or "./models/SEEDVR2")

    result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)

    generation_time = time.time() - processing_start
    if args.debug:
        print(f"🔄 Tempo de Geração: {generation_time:.2f}s")
        print(f"📊 Resultado: {result_tensor.shape}, dtype: {result_tensor.dtype}")

    return result_tensor, original_fps, generation_time, len(frames_tensor)
def main():
    """Command-line entry point: parse args, run the pipeline, save output."""
    started = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"🚀 SeedVR2 Video Upscaler CLI iniciado às {started}")
    args = parse_arguments()
    try:
        result_tensor, original_fps, _, _ = run_inference_logic(args)

        # Default output name derives from the input video's stem.
        output_path = args.output
        if output_path is None:
            output_path = f"result_{Path(args.video_path).stem}.mp4"
            args.output = output_path

        print(f"💾 Salvando vídeo em: {output_path}")
        save_frames_to_video(result_tensor, output_path, original_fps, args.debug)
        print("✅ Upscaling via CLI concluído com sucesso!")

    except Exception as e:
        print(f"❌ Erro durante o processamento via CLI: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    main()

# BUGFIX: removed the stray module-level call `SeedVRServer.setup_dependencies()`.
# It invoked an instance method through the class with no instance (TypeError:
# missing `self`) and would have run as a side effect of merely importing this
# module; the server already calls self.setup_dependencies() in its __init__.
+
360
 
361
 
362
 
 
376
  self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
377
 
378
 
379
+
380
+ #print("🚀 SeedVRServer ja inicializando...")
381
+
382
+ print("⚙️ SeedVRServer (Modo de Chamada Direta) inicializando...")
383
+ for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
384
+ p.mkdir(parents=True, exist_ok=True)
385
 
386
+ self.setup_dependencies()
387
+ print("📦 SeedVRServer (Modo de Chamada Direta) pronto.")
388
+ INIT = True
389
 
390
 
391
  def setup_dependencies(self):