File size: 13,383 Bytes
aa84ba3
 
 
46bb424
aa84ba3
 
 
 
 
 
 
46bb424
aa84ba3
 
46bb424
aa84ba3
46bb424
aa84ba3
 
46bb424
aa84ba3
 
 
 
 
 
46bb424
aa84ba3
 
 
46bb424
aa84ba3
 
 
 
 
 
 
46bb424
aa84ba3
 
 
 
 
 
 
 
 
46bb424
aa84ba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46bb424
aa84ba3
 
 
 
 
 
46bb424
aa84ba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46bb424
aa84ba3
46bb424
aa84ba3
d411dcc
 
 
 
 
 
 
 
 
46bb424
d411dcc
 
 
46bb424
 
d411dcc
 
 
 
46bb424
d411dcc
 
 
 
 
46bb424
d411dcc
46bb424
d411dcc
aa84ba3
46bb424
 
 
 
d411dcc
 
 
 
 
46bb424
d411dcc
 
 
 
 
 
 
 
 
 
 
 
 
 
46bb424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d411dcc
 
 
46bb424
 
 
 
d411dcc
 
 
 
aa84ba3
 
46bb424
aa84ba3
 
46bb424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa84ba3
 
 
 
 
46bb424
aa84ba3
 
46bb424
aa84ba3
 
 
 
 
 
 
 
46bb424
aa84ba3
 
 
 
 
 
 
46bb424
 
aa84ba3
 
 
 
 
 
46bb424
aa84ba3
 
 
46bb424
 
 
aa84ba3
 
 
46bb424
aa84ba3
46bb424
 
aa84ba3
 
46bb424
aa84ba3
 
46bb424
aa84ba3
 
 
 
46bb424
aa84ba3
46bb424
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
#!/usr/bin/env python3
"""
Standalone SeedVR2 Video Upscaler CLI Script
(MODIFICADO PARA SER IMPORTÁVEL E SUPORTAR CALLBACKS)
"""

import sys
import os
import argparse
import time
import multiprocessing as mp
# Garante o uso seguro de CUDA com multiprocessing, essencial para estabilidade.
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)

# -------------------------------------------------------------
# 1) Configuração de alocação de memória da VRAM (essencial para performance)
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")

# 2) Pré-análise dos argumentos para configurar a visibilidade dos dispositivos CUDA
_pre_parser = argparse.ArgumentParser(add_help=False)
_pre_parser.add_argument("--cuda_device", type=str, default=None)
_pre_args, _ = _pre_parser.parse_known_args()
if _pre_args.cuda_device is not None:
    device_list_env = [x.strip() for x in _pre_args.cuda_device.split(',') if x.strip()!='']
    if len(device_list_env) == 1:
        # Se apenas uma GPU for especificada, restringe a visibilidade do PyTorch a ela.
        os.environ["CUDA_VISIBLE_DEVICES"] = device_list_env[0]

# -------------------------------------------------------------
# 3) Importações pesadas (torch, etc.) são feitas após a configuração do ambiente.
import torch
import cv2
import numpy as np
from datetime import datetime
from pathlib import Path
from src.utils.downloads import download_weight

# Adiciona o diretório raiz do projeto ao path do sistema para permitir importações de `src`
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.insert(0, script_dir)
root_dir = os.path.join(script_dir, '..', '..')
if root_dir not in sys.path:
    sys.path.insert(0, root_dir)

def extract_frames_from_video(video_path, debug=False, skip_first_frames=0, load_cap=None):
    """
    Extrai quadros de um vídeo e os converte para o formato de tensor.
    """
    if debug:
        print(f"🎬 Extracting frames from video: {video_path}")
    
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")
    
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video file: {video_path}")
    
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    if debug:
        print(f"📊 Video info: {frame_count} frames, {width}x{height}, {fps:.2f} FPS")
        if skip_first_frames:
            print(f"⏭️ Will skip first {skip_first_frames} frames")
        if load_cap:
            print(f"🔢 Will load maximum {load_cap} frames")
    
    frames = []
    frame_idx = 0
    frames_loaded = 0
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        
        if frame_idx < skip_first_frames:
            frame_idx += 1
            continue
        
        if load_cap is not None and load_cap > 0 and frames_loaded >= load_cap:
            break
        
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = frame.astype(np.float32) / 255.0
        
        frames.append(frame)
        frame_idx += 1
        frames_loaded += 1
        
        if debug and frames_loaded % 100 == 0:
            total_to_load = min(frame_count, load_cap) if load_cap else frame_count
            print(f"📹 Extracted {frames_loaded}/{total_to_load} frames")
    
    cap.release()
    
    if len(frames) == 0:
        raise ValueError(f"No frames extracted from video: {video_path}")
    
    if debug:
        print(f"✅ Extracted {len(frames)} frames")
    
    frames_tensor = torch.from_numpy(np.stack(frames)).to(torch.float16)
    
    if debug:
        print(f"📊 Frames tensor shape: {frames_tensor.shape}, dtype: {frames_tensor.dtype}")
    
    return frames_tensor, fps


def save_frames_to_video(frames_tensor, output_path, fps=30.0, debug=False):
    """
    Salva um tensor de quadros em um arquivo de vídeo.
    """
    if debug:
        print(f"🎬 Saving {frames_tensor.shape[0]} frames to video: {output_path}")
    
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    frames_np = (frames_tensor.cpu().numpy() * 255.0).astype(np.uint8)
    
    T, H, W, C = frames_np.shape
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (W, H))
    
    if not out.isOpened():
        raise ValueError(f"Cannot create video writer for: {output_path}")
    
    for i, frame in enumerate(frames_np):
        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        out.write(frame_bgr)
        
        if debug and (i + 1) % 100 == 0:
            print(f"💾 Saved {i + 1}/{T} frames")
    
    out.release()
    
    if debug:
        print(f"✅ Video saved successfully: {output_path}")

def _worker_process(proc_idx, device_id, frames_np, shared_args, return_queue, progress_queue=None):
    """
    Processo filho (worker) que executa o upscaling em uma GPU dedicada.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "backend:cudaMallocAsync")

    import torch
    from src.core.model_manager import configure_runner
    from src.core.generation import generation_loop
    
    frames_tensor = torch.from_numpy(frames_np).to(torch.float16)
    
    # Cria uma função de callback local que envia o progresso para a fila de comunicação
    local_progress_callback = None
    if progress_queue:
        def callback_wrapper(batch_idx, total_batches, current_frames, message):
            # Envia uma tupla com informações de progresso para a fila
            progress_queue.put((proc_idx, batch_idx, total_batches, message))
        local_progress_callback = callback_wrapper

    runner = configure_runner(shared_args["model"], shared_args["model_dir"], shared_args["preserve_vram"], shared_args["debug"])

    # Passa o callback local para o generation_loop, que sabe como usá-lo
    result_tensor = generation_loop(
        runner=runner, images=frames_tensor, cfg_scale=shared_args["cfg_scale"],
        seed=shared_args["seed"], res_w=shared_args["res_w"], batch_size=shared_args["batch_size"],
        preserve_vram=shared_args["preserve_vram"], temporal_overlap=shared_args["temporal_overlap"],
        debug=shared_args["debug"],
        progress_callback=local_progress_callback 
    )
    # Envia o resultado final de volta para o processo pai
    return_queue.put((proc_idx, result_tensor.cpu().numpy()))

def _gpu_processing(frames_tensor, device_list, args, progress_callback=None):
    """
    Divide os quadros entre as GPUs e gerencia os processos filhos, monitorando o progresso.
    """
    num_devices = len(device_list)
    chunks = torch.chunk(frames_tensor, num_devices, dim=0)

    manager = mp.Manager()
    return_queue = manager.Queue()
    progress_queue = manager.Queue() if progress_callback else None
    workers = []

    shared_args = {
        "model": args.model, "model_dir": args.model_dir or "./models/SEEDVR2",
        "preserve_vram": args.preserve_vram, "debug": args.debug, "cfg_scale": 1.0,
        "seed": args.seed, "res_w": args.resolution, "batch_size": args.batch_size, "temporal_overlap": 0,
    }

    for idx, (device_id, chunk_tensor) in enumerate(zip(device_list, chunks)):
        p = mp.Process(target=_worker_process, args=(idx, device_id, chunk_tensor.cpu().numpy(), shared_args, return_queue, progress_queue))
        p.start()
        workers.append(p)

    results_np = [None] * num_devices
    collected_workers = 0
    worker_progress = [0.0] * num_devices # Armazena o progresso individual de cada worker
    
    while collected_workers < num_devices:
        # 1. Processa todas as mensagens de progresso na fila de forma não-bloqueante
        if progress_queue:
            while not progress_queue.empty():
                proc_idx, batch_idx, total_batches, message = progress_queue.get()
                if total_batches > 0:
                    worker_progress[proc_idx] = batch_idx / total_batches
                
                # Calcula o progresso geral como a média do progresso de todos os workers
                total_progress = sum(worker_progress) / num_devices
                
                # Chama o callback principal (do Gradio) com a informação formatada
                progress_callback(total_progress, desc=f"GPU {proc_idx+1}: {message}")

        # 2. Verifica se algum worker terminou e enviou seu resultado
        if not return_queue.empty():
            proc_idx, res_np = return_queue.get()
            results_np[proc_idx] = res_np
            worker_progress[proc_idx] = 1.0 # Marca este worker como 100% concluído
            collected_workers += 1

        time.sleep(0.2) # Pequena pausa para evitar uso excessivo da CPU no loop

    for p in workers: p.join()

    return torch.from_numpy(np.concatenate(results_np, axis=0)).to(torch.float16)

def parse_arguments():
    """Analisa os argumentos da linha de comando."""
    parser = argparse.ArgumentParser(description="SeedVR2 Video Upscaler CLI")
    
    parser.add_argument("--video_path", type=str, required=True, help="Path to input video file")
    parser.add_argument("--seed", type=int, default=100, help="Random seed for generation (default: 100)")
    parser.add_argument("--resolution", type=int, default=1072, help="Target resolution of the short side (default: 1072)")
    parser.add_argument("--batch_size", type=int, default=5, help="Number of frames per batch (default: 5)")
    parser.add_argument("--model", type=str, default="seedvr2_ema_3b_fp16.safetensors",
                        choices=["seedvr2_ema_3b_fp16.safetensors", "seedvr2_ema_3b_fp8_e4m3fn.safetensors", 
                                 "seedvr2_ema_7b_fp16.safetensors", "seedvr2_ema_7b_fp8_e4m3fn.safetensors"],
                        help="Model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="Directory containing the model files")
    parser.add_argument("--skip_first_frames", type=int, default=0, help="Skip the first frames during processing")
    parser.add_argument("--load_cap", type=int, default=0, help="Maximum number of frames to load from video (default: load all)")
    parser.add_argument("--output", type=str, default=None, help="Output path")
    parser.add_argument("--output_format", type=str, default="video", choices=["video", "png"], help="Output format: 'video' (mp4) or 'png' images")
    parser.add_argument("--preserve_vram", action="store_true", help="Enable VRAM preservation mode")
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    parser.add_argument("--cuda_device", type=str, default=None, help="CUDA device id(s). e.g., '0' or '0,1' for multi-GPU")
    
    return parser.parse_args()

def run_inference_logic(args, progress_callback=None):
    """
    Função principal que executa o pipeline de upscaling. Pode ser importada e chamada por outros scripts.
    """
    if args.debug:
        print(f"📋 Argumentos da Lógica de Inferência: {vars(args)}")

    print("🎬 Extraindo frames do vídeo...")
    start_time = time.time()
    frames_tensor, original_fps = extract_frames_from_video(
        args.video_path, args.debug, args.skip_first_frames, args.load_cap
    )
    if args.debug:
        print(f"🔄 Tempo de extração de frames: {time.time() - start_time:.2f}s")
        
    device_list = [d.strip() for d in str(args.cuda_device).split(',') if d.strip()] if args.cuda_device else ["0"]
    if args.debug:
        print(f"🚀 Usando dispositivos: {device_list}")
    
    processing_start = time.time()
    download_weight(args.model, args.model_dir)
    
    # Passa o callback para a função de processamento, que o gerenciará
    result_tensor = _gpu_processing(frames_tensor, device_list, args, progress_callback)

    generation_time = time.time() - processing_start
    if args.debug:
        print(f"🔄 Tempo de Geração: {generation_time:.2f}s")
        print(f"📊 Resultado: {result_tensor.shape}, dtype: {result_tensor.dtype}")
        
    # Retorna o tensor e metadados em memória para o chamador
    return result_tensor, original_fps, generation_time, len(frames_tensor)

def main():
    """
    Função principal para execução via linha de comando (CLI).
    """
    print(f"🚀 SeedVR2 Video Upscaler CLI iniciado às {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    args = parse_arguments()
    try:
        # Chama a função de lógica principal
        result_tensor, original_fps, _, _ = run_inference_logic(args)
        
        # Salva o resultado no disco, como esperado pelo modo CLI
        print(f"💾 Salvando vídeo em: {args.output}")
        save_frames_to_video(result_tensor, args.output, original_fps, args.debug)
        print("✅ Upscaling via CLI concluído com sucesso!")
        
    except Exception as e:
        print(f"❌ Erro durante o processamento via CLI: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

# Ponto de entrada para quando o script é executado diretamente
if __name__ == "__main__":
    main()