EuuIia committed on
Commit
664e50f
·
verified ·
1 Parent(s): 60281a6

Update inference_cli.py

Browse files
Files changed (1) hide show
  1. inference_cli.py +14 -16
inference_cli.py CHANGED
@@ -35,15 +35,15 @@ import cv2
35
  import numpy as np
36
  from datetime import datetime
37
  from pathlib import Path
38
- from src.utils.downloads import download_weight
39
 
40
  # Adiciona o diretório raiz do projeto ao path do sistema para permitir importações de `src`
 
41
  script_dir = os.path.dirname(os.path.abspath(__file__))
42
  if script_dir not in sys.path:
43
  sys.path.insert(0, script_dir)
44
- root_dir = os.path.join(script_dir, '..', '..')
45
- if root_dir not in sys.path:
46
- sys.path.insert(0, root_dir)
47
 
48
  def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
49
  """
@@ -68,7 +68,7 @@ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load
68
  print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
69
  if skip_first_frames:
70
  print(f"⏭️ Will skip first {skip_first_frames} frames")
71
- if load_cap:
72
  print(f"🔢 Will load maximum {load_cap} frames")
73
 
74
  frames = []
@@ -95,7 +95,7 @@ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load
95
  frames_loaded += 1
96
 
97
  if debug and frames_loaded % 100 == 0:
98
- total_to_load = min(frame_count, load_cap) if load_cap else frame_count
99
  print(f"📹 Extracted {frames_loaded}/{total_to_load} frames")
100
 
101
  cap.release()
@@ -211,39 +211,34 @@ def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):
211
  worker_progress = [0.0] * num_devices
212
 
213
  while finished_workers_count < num_devices:
214
- # 1. Processa todas as mensagens de progresso na fila
215
  if progress_queue:
216
  while not progress_queue.empty():
217
  try:
218
  proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
219
- if batch_idx == -1: # Mensagem de erro do worker
220
  raise RuntimeError(f"Worker {proc_idx} error: {message}")
221
-
222
  if total_batches > 0:
223
  worker_progress[proc_idx] = batch_idx / total_batches
224
-
225
  total_progress = sum(worker_progress) / num_devices
226
  progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
227
  except queue.Empty:
228
  break
229
 
230
- # 2. Verifica se algum worker terminou
231
  try:
232
- proc_idx, result = return_queue.get(timeout=0.1) # Usa um timeout curto
233
  if isinstance(result, str) and result.startswith("ERROR"):
234
  raise RuntimeError(f"Worker {proc_idx} failed: {result}")
235
  results_np[proc_idx] = result
236
- worker_progress[proc_idx] = 1.0 # Marca como 100%
237
  finished_workers_count += 1
238
  if progress_callback:
239
  total_progress = sum(worker_progress) / num_devices
240
  progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Completed!")
241
  except queue.Empty:
242
- pass # Continua o loop se não houver resultados ainda
243
 
244
  for p in workers: p.join()
245
 
246
- # Verifica se algum resultado está faltando, indicando um erro não capturado
247
  if any(r is None for r in results_np):
248
  raise RuntimeError("One or more workers failed to return a result.")
249
 
@@ -293,7 +288,7 @@ def run_inference_logic(args, progress_callback=None):
293
 
294
  if progress_callback: progress_callback(0.1, "Starting generation...")
295
  processing_start = time.time()
296
- download_weight(args.model, args.model_dir)
297
 
298
  result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)
299
 
@@ -313,6 +308,9 @@ def main():
313
  try:
314
  result_tensor, original_fps, _, _ = run_inference_logic(args)
315
 
 
 
 
316
  print(f"💾 Salvando vídeo em: {args.output}")
317
  save_frames_to_video(result_tensor, args.output, original_fps, args.debug)
318
  print("✅ Upscaling via CLI concluído com sucesso!")
 
35
  import numpy as np
36
  from datetime import datetime
37
  from pathlib import Path
 
38
 
39
  # Adiciona o diretório raiz do projeto ao path do sistema para permitir importações de `src`
40
+ # Isso assume que o script está dentro do repositório clonado.
41
  script_dir = os.path.dirname(os.path.abspath(__file__))
42
  if script_dir not in sys.path:
43
  sys.path.insert(0, script_dir)
44
+
45
+ # Importa as funções do SeedVR DEPOIS de ajustar o path.
46
+ from src.utils.downloads import download_weight
47
 
48
  def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
49
  """
 
68
  print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
69
  if skip_first_frames:
70
  print(f"⏭️ Will skip first {skip_first_frames} frames")
71
+ if load_cap and load_cap > 0:
72
  print(f"🔢 Will load maximum {load_cap} frames")
73
 
74
  frames = []
 
95
  frames_loaded += 1
96
 
97
  if debug and frames_loaded % 100 == 0:
98
+ total_to_load = min(frame_count, load_cap) if (load_cap and load_cap > 0) else frame_count
99
  print(f"📹 Extracted {frames_loaded}/{total_to_load} frames")
100
 
101
  cap.release()
 
211
  worker_progress = [0.0] * num_devices
212
 
213
  while finished_workers_count < num_devices:
 
214
  if progress_queue:
215
  while not progress_queue.empty():
216
  try:
217
  proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
218
+ if batch_idx == -1:
219
  raise RuntimeError(f"Worker {proc_idx} error: {message}")
 
220
  if total_batches > 0:
221
  worker_progress[proc_idx] = batch_idx / total_batches
 
222
  total_progress = sum(worker_progress) / num_devices
223
  progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
224
  except queue.Empty:
225
  break
226
 
 
227
  try:
228
+ proc_idx, result = return_queue.get(timeout=0.2)
229
  if isinstance(result, str) and result.startswith("ERROR"):
230
  raise RuntimeError(f"Worker {proc_idx} failed: {result}")
231
  results_np[proc_idx] = result
232
+ worker_progress[proc_idx] = 1.0
233
  finished_workers_count += 1
234
  if progress_callback:
235
  total_progress = sum(worker_progress) / num_devices
236
  progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Completed!")
237
  except queue.Empty:
238
+ pass
239
 
240
  for p in workers: p.join()
241
 
 
242
  if any(r is None for r in results_np):
243
  raise RuntimeError("One or more workers failed to return a result.")
244
 
 
288
 
289
  if progress_callback: progress_callback(0.1, "Starting generation...")
290
  processing_start = time.time()
291
+ download_weight(args.model, args.model_dir or "seedvr_models")
292
 
293
  result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)
294
 
 
308
  try:
309
  result_tensor, original_fps, _, _ = run_inference_logic(args)
310
 
311
+ if args.output is None:
312
+ args.output = f"result_{Path(args.video_path).stem}.mp4"
313
+
314
  print(f"💾 Salvando vídeo em: {args.output}")
315
  save_frames_to_video(result_tensor, args.output, original_fps, args.debug)
316
  print("✅ Upscaling via CLI concluído com sucesso!")