Update inference_cli.py
Browse files- inference_cli.py +14 -16
inference_cli.py
CHANGED
|
@@ -35,15 +35,15 @@ import cv2
|
|
| 35 |
import numpy as np
|
| 36 |
from datetime import datetime
|
| 37 |
from pathlib import Path
|
| 38 |
-
from src.utils.downloads import download_weight
|
| 39 |
|
| 40 |
# Adiciona o diretório raiz do projeto ao path do sistema para permitir importações de `src`
|
|
|
|
| 41 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 42 |
if script_dir not in sys.path:
|
| 43 |
sys.path.insert(0, script_dir)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
|
| 49 |
"""
|
|
@@ -68,7 +68,7 @@ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load
|
|
| 68 |
print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
|
| 69 |
if skip_first_frames:
|
| 70 |
print(f"⏭️ Will skip first {skip_first_frames} frames")
|
| 71 |
-
if load_cap:
|
| 72 |
print(f"🔢 Will load maximum {load_cap} frames")
|
| 73 |
|
| 74 |
frames = []
|
|
@@ -95,7 +95,7 @@ def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load
|
|
| 95 |
frames_loaded += 1
|
| 96 |
|
| 97 |
if debug and frames_loaded % 100 == 0:
|
| 98 |
-
total_to_load = min(frame_count, load_cap) if load_cap else frame_count
|
| 99 |
print(f"📹 Extracted {frames_loaded}/{total_to_load} frames")
|
| 100 |
|
| 101 |
cap.release()
|
|
@@ -211,39 +211,34 @@ def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):
|
|
| 211 |
worker_progress = [0.0] * num_devices
|
| 212 |
|
| 213 |
while finished_workers_count < num_devices:
|
| 214 |
-
# 1. Processa todas as mensagens de progresso na fila
|
| 215 |
if progress_queue:
|
| 216 |
while not progress_queue.empty():
|
| 217 |
try:
|
| 218 |
proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
|
| 219 |
-
if batch_idx == -1:
|
| 220 |
raise RuntimeError(f"Worker {proc_idx} error: {message}")
|
| 221 |
-
|
| 222 |
if total_batches > 0:
|
| 223 |
worker_progress[proc_idx] = batch_idx / total_batches
|
| 224 |
-
|
| 225 |
total_progress = sum(worker_progress) / num_devices
|
| 226 |
progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
|
| 227 |
except queue.Empty:
|
| 228 |
break
|
| 229 |
|
| 230 |
-
# 2. Verifica se algum worker terminou
|
| 231 |
try:
|
| 232 |
-
proc_idx, result = return_queue.get(timeout=0.
|
| 233 |
if isinstance(result, str) and result.startswith("ERROR"):
|
| 234 |
raise RuntimeError(f"Worker {proc_idx} failed: {result}")
|
| 235 |
results_np[proc_idx] = result
|
| 236 |
-
worker_progress[proc_idx] = 1.0
|
| 237 |
finished_workers_count += 1
|
| 238 |
if progress_callback:
|
| 239 |
total_progress = sum(worker_progress) / num_devices
|
| 240 |
progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Completed!")
|
| 241 |
except queue.Empty:
|
| 242 |
-
pass
|
| 243 |
|
| 244 |
for p in workers: p.join()
|
| 245 |
|
| 246 |
-
# Verifica se algum resultado está faltando, indicando um erro não capturado
|
| 247 |
if any(r is None for r in results_np):
|
| 248 |
raise RuntimeError("One or more workers failed to return a result.")
|
| 249 |
|
|
@@ -293,7 +288,7 @@ def run_inference_logic(args, progress_callback=None):
|
|
| 293 |
|
| 294 |
if progress_callback: progress_callback(0.1, "Starting generation...")
|
| 295 |
processing_start = time.time()
|
| 296 |
-
download_weight(args.model, args.model_dir)
|
| 297 |
|
| 298 |
result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)
|
| 299 |
|
|
@@ -313,6 +308,9 @@ def main():
|
|
| 313 |
try:
|
| 314 |
result_tensor, original_fps, _, _ = run_inference_logic(args)
|
| 315 |
|
|
|
|
|
|
|
|
|
|
| 316 |
print(f"💾 Salvando vídeo em: {args.output}")
|
| 317 |
save_frames_to_video(result_tensor, args.output, original_fps, args.debug)
|
| 318 |
print("✅ Upscaling via CLI concluído com sucesso!")
|
|
|
|
| 35 |
import numpy as np
|
| 36 |
from datetime import datetime
|
| 37 |
from pathlib import Path
|
|
|
|
| 38 |
|
| 39 |
# Adiciona o diretório raiz do projeto ao path do sistema para permitir importações de `src`
|
| 40 |
+
# Isso assume que o script está dentro do repositório clonado.
|
| 41 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 42 |
if script_dir not in sys.path:
|
| 43 |
sys.path.insert(0, script_dir)
|
| 44 |
+
|
| 45 |
+
# Importa as funções do SeedVR DEPOIS de ajustar o path.
|
| 46 |
+
from src.utils.downloads import download_weight
|
| 47 |
|
| 48 |
def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
|
| 49 |
"""
|
|
|
|
| 68 |
print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
|
| 69 |
if skip_first_frames:
|
| 70 |
print(f"⏭️ Will skip first {skip_first_frames} frames")
|
| 71 |
+
if load_cap and load_cap > 0:
|
| 72 |
print(f"🔢 Will load maximum {load_cap} frames")
|
| 73 |
|
| 74 |
frames = []
|
|
|
|
| 95 |
frames_loaded += 1
|
| 96 |
|
| 97 |
if debug and frames_loaded % 100 == 0:
|
| 98 |
+
total_to_load = min(frame_count, load_cap) if (load_cap and load_cap > 0) else frame_count
|
| 99 |
print(f"📹 Extracted {frames_loaded}/{total_to_load} frames")
|
| 100 |
|
| 101 |
cap.release()
|
|
|
|
| 211 |
worker_progress = [0.0] * num_devices
|
| 212 |
|
| 213 |
while finished_workers_count < num_devices:
|
|
|
|
| 214 |
if progress_queue:
|
| 215 |
while not progress_queue.empty():
|
| 216 |
try:
|
| 217 |
proc_idx, batch_idx, total_batches, message = progress_queue.get_nowait()
|
| 218 |
+
if batch_idx == -1:
|
| 219 |
raise RuntimeError(f"Worker {proc_idx} error: {message}")
|
|
|
|
| 220 |
if total_batches > 0:
|
| 221 |
worker_progress[proc_idx] = batch_idx / total_batches
|
|
|
|
| 222 |
total_progress = sum(worker_progress) / num_devices
|
| 223 |
progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: {message}")
|
| 224 |
except queue.Empty:
|
| 225 |
break
|
| 226 |
|
|
|
|
| 227 |
try:
|
| 228 |
+
proc_idx, result = return_queue.get(timeout=0.2)
|
| 229 |
if isinstance(result, str) and result.startswith("ERROR"):
|
| 230 |
raise RuntimeError(f"Worker {proc_idx} failed: {result}")
|
| 231 |
results_np[proc_idx] = result
|
| 232 |
+
worker_progress[proc_idx] = 1.0
|
| 233 |
finished_workers_count += 1
|
| 234 |
if progress_callback:
|
| 235 |
total_progress = sum(worker_progress) / num_devices
|
| 236 |
progress_callback(total_progress, desc=f"GPU {proc_idx+1}/{num_devices}: Completed!")
|
| 237 |
except queue.Empty:
|
| 238 |
+
pass
|
| 239 |
|
| 240 |
for p in workers: p.join()
|
| 241 |
|
|
|
|
| 242 |
if any(r is None for r in results_np):
|
| 243 |
raise RuntimeError("One or more workers failed to return a result.")
|
| 244 |
|
|
|
|
| 288 |
|
| 289 |
if progress_callback: progress_callback(0.1, "Starting generation...")
|
| 290 |
processing_start = time.time()
|
| 291 |
+
download_weight(args.model, args.model_dir or "seedvr_models")
|
| 292 |
|
| 293 |
result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)
|
| 294 |
|
|
|
|
| 308 |
try:
|
| 309 |
result_tensor, original_fps, _, _ = run_inference_logic(args)
|
| 310 |
|
| 311 |
+
if args.output is None:
|
| 312 |
+
args.output = f"result_{Path(args.video_path).stem}.mp4"
|
| 313 |
+
|
| 314 |
print(f"💾 Salvando vídeo em: {args.output}")
|
| 315 |
save_frames_to_video(result_tensor, args.output, original_fps, args.debug)
|
| 316 |
print("✅ Upscaling via CLI concluído com sucesso!")
|