EuuIia committed on
Commit
76540a2
·
verified ·
1 Parent(s): 4ef720b

Update inference_cli.py

Browse files
Files changed (1) hide show
  1. inference_cli.py +74 -115
inference_cli.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
  Standalone SeedVR2 Video Upscaler CLI Script
4
- (MODIFICADO PARA SER IMPORTÁVEL E SUPORTAR CALLBACKS)
5
  """
6
 
7
  import sys
@@ -9,12 +9,14 @@ import os
9
  import argparse
10
  import time
11
  import multiprocessing as mp
 
 
12
  # Garante o uso seguro de CUDA com multiprocessing, essencial para estabilidade.
13
  if mp.get_start_method(allow_none=True) != 'spawn':
14
  mp.set_start_method('spawn', force=True)
15
 
16
  # -------------------------------------------------------------
17
- # 1) Configuração de alocação de memória da VRAM (essencial para performance)
18
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
19
 
20
  # 2) Pré-análise dos argumentos para configurar a visibilidade dos dispositivos CUDA
@@ -24,7 +26,6 @@ _pre_args, _ = _pre_parser.parse_known_args()
24
  if _pre_args.cuda_device is not None:
25
  device_list_env = [x.strip() for x in _pre_args.cuda_device.split(',') if x.strip()!='']
26
  if len(device_list_env) == 1:
27
- # Se apenas uma GPU for especificada, restringe a visibilidade do PyTorch a ela.
28
  os.environ["CUDA_VISIBLE_DEVICES"] = device_list_env[0]
29
 
30
  # -------------------------------------------------------------
@@ -48,101 +49,45 @@ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load
48
  """
49
  Extrai quadros de um vídeo e os converte para o formato de tensor.
50
  """
51
- if debug:
52
- print(f"🎬 Extracting frames from video: {video_path}")
53
-
54
- if not os.path.exists(video_path):
55
- raise FileNotFoundError(f"Video file not found: {video_path}")
56
-
57
  cap = cv2.VideoCapture(video_path)
58
- if not cap.isOpened():
59
- raise ValueError(f"Cannot open video file: {video_path}")
60
-
61
  fps = cap.get(cv2.CAP_PROP_FPS)
62
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
63
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
64
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
65
-
66
- if debug:
67
- print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
68
- if skip_first_frames:
69
- print(f"⏭️ Will skip first {skip_first_frames} frames")
70
- if load_cap:
71
- print(f"🔢 Will load maximum {load_cap} frames")
72
-
73
- frames = []
74
- frame_idx = 0
75
- frames_loaded = 0
76
-
77
  while True:
78
  ret, frame = cap.read()
79
- if not ret:
80
- break
81
-
82
- if frame_idx < skip_first_frames:
83
- frame_idx += 1
84
- continue
85
-
86
- if load_cap is not None and load_cap > 0 and frames_loaded >= load_cap:
87
- break
88
-
89
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
90
  frame = frame.astype(np.float32) / 255.0
91
-
92
- frames.append(frame)
93
- frame_idx += 1
94
- frames_loaded += 1
95
-
96
- if debug and frames_loaded % 100 == 0:
97
- total_to_load = min(frame_count, load_cap) if load_cap else frame_count
98
- print(f"📹 Extracted {frames_loaded}/{total_to_load} frames")
99
-
100
  cap.release()
101
-
102
- if len(frames) == 0:
103
- raise ValueError(f"No frames extracted from video: {video_path}")
104
-
105
- if debug:
106
- print(f"✅ Extracted {len(frames)} frames")
107
-
108
  frames_tensor = torch.from_numpy(np.stack(frames)).to(torch.float16)
109
-
110
- if debug:
111
- print(f"📊 Frames tensor shape: {frames_tensor.shape}, dtype: {frames_tensor.dtype}")
112
-
113
  return frames_tensor, fps
114
 
115
-
116
  def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
117
  """
118
  Salva um tensor de quadros em um arquivo de vídeo.
119
  """
120
- if debug:
121
- print(f"🎬 Saving {frames_tensor.shape[0]} frames to video: {output_path}")
122
-
123
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
124
-
125
  frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
126
-
127
  T, H, W, C = frames_np.shape
128
-
129
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
130
  out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
131
-
132
- if not out.isOpened():
133
- raise ValueError(f"Cannot create video writer for: {output_path}")
134
-
135
  for i, frame in enumerate(frames_np):
136
  frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
137
  out.write(frame_bgr)
138
-
139
- if debug and (i + 1) % 100 == 0:
140
- print(f"💾 Saved {i + 1}/{T} frames")
141
-
142
  out.release()
143
-
144
- if debug:
145
- print(f"✅ Video saved successfully: {output_path}")
146
 
147
  def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
148
  """
@@ -157,30 +102,33 @@ def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, p
157
 
158
  frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
159
 
160
- # Cria uma função de callback local que envia o progresso para a fila de comunicação
161
  local_progress_callback = None
162
  if progress_queue:
163
  def callback_wrapper(batch_idx, total_batches, current_frames, message):
164
- # Envia uma tupla com informações de progresso para a fila
165
  progress_queue.put((proc_idx, batch_idx, total_batches, message))
166
  local_progress_callback = callback_wrapper
167
 
168
- runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
169
-
170
- # Passa o callback local para o generation_loop, que sabe como usá-lo
171
- result_tensor = generation_loop(
172
- runner=runner, images=frames_tensor, cfg_scale=shared_args["cfg_scale"],
173
- seed=shared_args["seed"], res_w=shared_args["res_w"], batch_size=shared_args["batch_size"],
174
- preserve_vram=shared_args["preserve_vram"], temporal_overlap=shared_args["temporal_overlap"],
175
- debug=shared_args["debug"],
176
- progress_callback=local_progress_callback
177
- )
178
- # Envia o resultado final de volta para o processo pai
179
- return_queue.put((proc_idx, result_tensor.cpu().numpy()))
 
 
 
 
 
180
 
181
  def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):
182
  """
183
- Divide os quadros entre as GPUs e gerencia os processos filhos, monitorando o progresso.
184
  """
185
  num_devices = len(device_list)
186
  chunks = torch.chunk(frames_tensor, num_devices, dim=0)
@@ -202,40 +150,54 @@ def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):
202
  workers.append(p)
203
 
204
  results_np = [None] * num_devices
205
- collected_workers = 0
206
- worker_progress = [0.0] * num_devices # Armazena o progresso individual de cada worker
207
 
208
- while collected_workers < num_devices:
209
- # 1. Processa todas as mensagens de progresso na fila de forma não-bloqueante
210
  if progress_queue:
211
  while not progress_queue.empty():
212
- proc_idx, batch_idx, total_batches, message = progress_queue.get()
213
- if total_batches > 0:
214
- worker_progress[proc_idx] = batch_idx / total_batches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- # Calcula o progresso geral como a média do progresso de todos os workers
217
  total_progress = sum(worker_progress) / num_devices
218
-
219
- # Chama o callback principal (do Gradio) com a informação formatada
220
- progress_callback(total_progress, desc=f"GPU {proc_idx+1}: {message}")
221
-
222
- # 2. Verifica se algum worker terminou e enviou seu resultado
223
- if not return_queue.empty():
224
- proc_idx, res_np = return_queue.get()
225
- results_np[proc_idx] = res_np
226
- worker_progress[proc_idx] = 1.0 # Marca este worker como 100% concluído
227
- collected_workers += 1
228
 
229
- time.sleep(0.2) # Pequena pausa para evitar uso excessivo da CPU no loop
230
 
231
  for p in workers: p.join()
232
 
 
 
 
 
233
  return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
234
 
235
  def parse_arguments():
236
  """Analisa os argumentos da linha de comando."""
237
  parser = argparse.ArgumentParser(description="SeedVR2 Video Upscaler CLI")
238
-
239
  parser.add_argument("--video_path", type=str, required=True, help="Path to input video file")
240
  parser.add_argument("--seed", type=int, default=100, help="Random seed for generation (default: 100)")
241
  parser.add_argument("--resolution", type=int, default=1072, help="Target resolution of the short side (default: 1072)")
@@ -262,6 +224,7 @@ def run_inference_logic(args, progress_callback=None):
262
  if args.debug:
263
  print(f"📋 Argumentos da Lógica de Inferência: {vars(args)}")
264
 
 
265
  print("🎬 Extraindo frames do vídeo...")
266
  start_time = time.time()
267
  frames_tensor, original_fps = extract_frames_from_video(
@@ -274,10 +237,10 @@ def run_inference_logic(args, progress_callback=None):
274
  if args.debug:
275
  print(f"🚀 Usando dispositivos: {device_list}")
276
 
 
277
  processing_start = time.time()
278
  download_weight(args.model, args.model_dir)
279
 
280
- # Passa o callback para a função de processamento, que o gerenciará
281
  result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)
282
 
283
  generation_time = time.time() - processing_start
@@ -285,7 +248,6 @@ def run_inference_logic(args, progress_callback=None):
285
  print(f"🔄 Tempo de Geração: {generation_time:.2f}s")
286
  print(f"📊 Resultado: {result_tensor.shape}, dtype: {result_tensor.dtype}")
287
 
288
- # Retorna o tensor e metadados em memória para o chamador
289
  return result_tensor, original_fps, generation_time, len(frames_tensor)
290
 
291
  def main():
@@ -295,10 +257,8 @@ def main():
295
  print(f"🚀 SeedVR2 Video Upscaler CLI iniciado às {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
296
  args = parse_arguments()
297
  try:
298
- # Chama a função de lógica principal
299
  result_tensor, original_fps, _, _ = run_inference_logic(args)
300
 
301
- # Salva o resultado no disco, como esperado pelo modo CLI
302
  print(f"💾 Salvando vídeo em: {args.output}")
303
  save_frames_to_video(result_tensor, args.output, original_fps, args.debug)
304
  print("✅ Upscaling via CLI concluído com sucesso!")
@@ -309,6 +269,5 @@ def main():
309
  traceback.print_exc()
310
  sys.exit(1)
311
 
312
- # Ponto de entrada para quando o script é executado diretamente
313
  if __name__ == "__main__":
314
  main()
 
1
  #!/usr/bin/env python3
2
  """
3
  Standalone SeedVR2 Video Upscaler CLI Script
4
+ (MODIFICADO PARA SER IMPORTÁVEL E SUPORTAR CALLBACKS DE PROGRESSO EM MULTIPROCESSING)
5
  """
6
 
7
  import sys
 
9
  import argparse
10
  import time
11
  import multiprocessing as mp
12
+ import queue # Importa a classe de exceção para filas vazias
13
+
14
  # Garante o uso seguro de CUDA com multiprocessing, essencial para estabilidade.
15
  if mp.get_start_method(allow_none=True) != 'spawn':
16
  mp.set_start_method('spawn', force=True)
17
 
18
  # -------------------------------------------------------------
19
+ # 1) Configuração de alocação de memória da VRAM
20
  os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")
21
 
22
  # 2) Pré-análise dos argumentos para configurar a visibilidade dos dispositivos CUDA
 
26
  if _pre_args.cuda_device is not None:
27
  device_list_env = [x.strip() for x in _pre_args.cuda_device.split(',') if x.strip()!='']
28
  if len(device_list_env) == 1:
 
29
  os.environ["CUDA_VISIBLE_DEVICES"] = device_list_env[0]
30
 
31
  # -------------------------------------------------------------
 
49
  """
50
  Extrai quadros de um vídeo e os converte para o formato de tensor.
51
  """
52
+ if debug: print(f"🎬 Extracting frames from video: {video_path}")
53
+ if not os.path.exists(video_path): raise FileNotFoundError(f"Video file not found: {video_path}")
 
 
 
 
54
  cap = cv2.VideoCapture(video_path)
55
+ if not cap.isOpened(): raise ValueError(f"Cannot open video file: {video_path}")
 
 
56
  fps = cap.get(cv2.CAP_PROP_FPS)
57
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)); width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)); height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
58
+ if debug: print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
59
+ frames = []; frame_idx = 0; frames_loaded = 0
 
 
 
 
 
 
 
 
 
 
 
 
60
  while True:
61
  ret, frame = cap.read()
62
+ if not ret: break
63
+ if frame_idx < skip_first_frames: frame_idx += 1; continue
64
+ if load_cap is not None and load_cap > 0 and frames_loaded >= load_cap: break
 
 
 
 
 
 
 
65
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66
  frame = frame.astype(np.float32) / 255.0
67
+ frames.append(frame); frame_idx += 1; frames_loaded += 1
 
 
 
 
 
 
 
 
68
  cap.release()
69
+ if len(frames) == 0: raise ValueError(f"No frames extracted from video: {video_path}")
70
+ if debug: print(f"✅ Extracted {len(frames)} frames")
 
 
 
 
 
71
  frames_tensor = torch.from_numpy(np.stack(frames)).to(torch.float16)
72
+ if debug: print(f"📊 Frames tensor shape: {frames_tensor.shape}, dtype: {frames_tensor.dtype}")
 
 
 
73
  return frames_tensor, fps
74
 
 
75
  def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
76
  """
77
  Salva um tensor de quadros em um arquivo de vídeo.
78
  """
79
+ if debug: print(f"🎬 Saving {frames_tensor.shape[0]} frames to video: {output_path}")
 
 
80
  os.makedirs(os.path.dirname(output_path), exist_ok=True)
 
81
  frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
 
82
  T, H, W, C = frames_np.shape
 
83
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
84
  out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
85
+ if not out.isOpened(): raise ValueError(f"Cannot create video writer for: {output_path}")
 
 
 
86
  for i, frame in enumerate(frames_np):
87
  frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
88
  out.write(frame_bgr)
 
 
 
 
89
  out.release()
90
+ if debug: print(f"✅ Video saved successfully: {output_path}")
 
 
91
 
92
  def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
93
  """
 
102
 
103
  frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
104
 
 
105
  local_progress_callback = None
106
  if progress_queue:
107
  def callback_wrapper(batch_idx, total_batches, current_frames, message):
 
108
  progress_queue.put((proc_idx, batch_idx, total_batches, message))
109
  local_progress_callback = callback_wrapper
110
 
111
+ try:
112
+ runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])
113
+ result_tensor = generation_loop(
114
+ runner=runner, images=frames_tensor, cfg_scale=shared_args["cfg_scale"],
115
+ seed=shared_args["seed"], res_w=shared_args["res_w"], batch_size=shared_args["batch_size"],
116
+ preserve_vram=shared_args["preserve_vram"], temporal_overlap=shared_args["temporal_overlap"],
117
+ debug=shared_args["debug"],
118
+ progress_callback=local_progress_callback
119
+ )
120
+ return_queue.put((proc_idx, result_tensor.cpu().numpy()))
121
+ except Exception as e:
122
+ import traceback
123
+ error_msg = f"ERROR in worker {proc_idx}: {e}\n{traceback.format_exc()}"
124
+ print(error_msg)
125
+ if progress_queue:
126
+ progress_queue.put((proc_idx, -1, -1, error_msg))
127
+ return_queue.put((proc_idx, error_msg))
128
 
129
  def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):
130
  """
131
+ Divide os quadros, gerencia os workers e monitora o progresso de forma robusta.
132
  """
133
  num_devices = len(device_list)
134
  chunks = torch.chunk(frames_tensor, num_devices, dim=0)
 
150
  workers.append(p)
151
 
152
  results_np = [None] * num_devices
153
+ finished_workers = [False] * num_devices
154
+ worker_progress = [0.0] * num_devices
155
 
156
+ while not all(finished_workers):
 
157
  if progress_queue:
158
  while not progress_queue.empty():
159
+ try:
160
+ proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
161
+ if batch_idx == -1: # Mensagem de erro do worker
162
+ raise RuntimeError(f"Worker {proc_idx} encontrou um erro: {message}")
163
+
164
+ if total_batches > 0:
165
+ worker_progress[proc_idx] = batch_idx / total_batches
166
+
167
+ total_progress = sum(worker_progress) / num_devices
168
+ progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
169
+ except queue.Empty:
170
+ break
171
+
172
+ while not return_queue.empty():
173
+ try:
174
+ proc_idx, result = return_queue.get_nowait()
175
+ if isinstance(result, str) and result.startswith("ERROR"):
176
+ raise RuntimeError(f"Worker {proc_idx} falhou: {result}")
177
+
178
+ results_np[proc_idx] = result
179
+ worker_progress[proc_idx] = 1.0 # Marca como 100% concluído
180
+ finished_workers[proc_idx] = True
181
 
 
182
  total_progress = sum(worker_progress) / num_devices
183
+ if progress_callback:
184
+ progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Concluído!")
185
+ except queue.Empty:
186
+ break
 
 
 
 
 
 
187
 
188
+ time.sleep(0.2)
189
 
190
  for p in workers: p.join()
191
 
192
+ # Verifica se algum resultado está faltando, indicando um erro não capturado
193
+ if any(r is None for r in results_np):
194
+ raise RuntimeError("Um ou mais workers falharam em retornar um resultado.")
195
+
196
  return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)
197
 
198
  def parse_arguments():
199
  """Analisa os argumentos da linha de comando."""
200
  parser = argparse.ArgumentParser(description="SeedVR2 Video Upscaler CLI")
 
201
  parser.add_argument("--video_path", type=str, required=True, help="Path to input video file")
202
  parser.add_argument("--seed", type=int, default=100, help="Random seed for generation (default: 100)")
203
  parser.add_argument("--resolution", type=int, default=1072, help="Target resolution of the short side (default: 1072)")
 
224
  if args.debug:
225
  print(f"📋 Argumentos da Lógica de Inferência: {vars(args)}")
226
 
227
+ if progress_callback: progress_callback(0.05, "Extracting frames...")
228
  print("🎬 Extraindo frames do vídeo...")
229
  start_time = time.time()
230
  frames_tensor, original_fps = extract_frames_from_video(
 
237
  if args.debug:
238
  print(f"🚀 Usando dispositivos: {device_list}")
239
 
240
+ if progress_callback: progress_callback(0.1, "Starting generation...")
241
  processing_start = time.time()
242
  download_weight(args.model, args.model_dir)
243
 
 
244
  result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)
245
 
246
  generation_time = time.time() - processing_start
 
248
  print(f"🔄 Tempo de Geração: {generation_time:.2f}s")
249
  print(f"📊 Resultado: {result_tensor.shape}, dtype: {result_tensor.dtype}")
250
 
 
251
  return result_tensor, original_fps, generation_time, len(frames_tensor)
252
 
253
  def main():
 
257
  print(f"🚀 SeedVR2 Video Upscaler CLI iniciado às {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
258
  args = parse_arguments()
259
  try:
 
260
  result_tensor, original_fps, _, _ = run_inference_logic(args)
261
 
 
262
  print(f"💾 Salvando vídeo em: {args.output}")
263
  save_frames_to_video(result_tensor, args.output, original_fps, args.debug)
264
  print("✅ Upscaling via CLI concluído com sucesso!")
 
269
  traceback.print_exc()
270
  sys.exit(1)
271
 
 
272
  if __name__ == "__main__":
273
  main()