File size: 5,512 Bytes
7dd33fe
 
a5720bf
 
7dd33fe
a5720bf
7dd33fe
a5720bf
c96bb7c
7dd33fe
 
a5720bf
 
7dd33fe
 
 
 
 
 
 
 
 
 
 
 
 
a5720bf
 
7dd33fe
 
 
 
 
a5720bf
 
 
7dd33fe
 
 
 
 
a5720bf
7dd33fe
a5720bf
7dd33fe
a5720bf
 
7dd33fe
a5720bf
 
 
 
7dd33fe
a5720bf
 
7dd33fe
a5720bf
 
 
 
7dd33fe
a5720bf
 
c96bb7c
 
 
7dd33fe
a5720bf
 
 
 
c96bb7c
7dd33fe
 
c96bb7c
a5720bf
7dd33fe
 
c96bb7c
7dd33fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c96bb7c
a5720bf
7dd33fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5720bf
7dd33fe
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# api/seedvr_server.py

import os
import sys
import shutil
import mimetypes
import time
from pathlib import Path
from typing import Optional, Tuple
from types import SimpleNamespace

from huggingface_hub import hf_hub_download

# Adiciona dinamicamente o caminho do repositório clonado ao sys.path
# Isso é crucial para que a importação do 'inference_cli' funcione.
SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
if str(SEEDVR_REPO_PATH) not in sys.path:
    sys.path.insert(0, str(SEEDVR_REPO_PATH))

# Tenta importar as funções necessárias APÓS a modificação do path.
try:
    from inference_cli import run_inference_logic, save_frames_to_video
except ImportError as e:
    print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'. Verifique se o repositório em {SEEDVR_REPO_PATH} está correto.")
    raise e

class SeedVRServer:
    def __init__(self, **kwargs):
        """
        Inicializa o servidor, define os caminhos e prepara o ambiente.
        """
        self.SEEDVR_ROOT = SEEDVR_REPO_PATH
        self.CKPTS_ROOT = Path("/data/seedvr_models_fp16") 
        self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
        self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
        self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
        
        print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
        for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
            p.mkdir(parents=True, exist_ok=True)
        
        self.setup_dependencies()
        print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")

    def setup_dependencies(self):
        """ Garante que o repositório e os modelos estão presentes. """
        self._ensure_repo()
        self._ensure_model()

    def _ensure_repo(self) -> None:
        """ Clona o repositório do SeedVR se ele não existir. """
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
            subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print("[SeedVRServer] Repositório SeedVR já existe.")

    def _ensure_model(self) -> None:
        """ Baixa os checkpoints do Hugging Face se não existirem localmente. """
        print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
        model_files = {
            "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
            "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
        }
        for filename, repo_id in model_files.items():
            if not (self.CKPTS_ROOT / filename).exists():
                print(f"Baixando {filename} de {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
                    cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
                )
        print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
        
    def run_inference_direct(
        self,
        file_path: str, *,
        seed: int, res_h: int, res_w: int, sp_size: int,
        fps: Optional[float] = None, progress=None
    ) -> str:
        """
        Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
        """
        out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}"
        out_dir.mkdir(parents=True, exist_ok=True)
        output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"

        # Simula o objeto 'args' que a função de lógica espera
        args = SimpleNamespace(
            video_path=file_path,
            output=str(output_filepath),
            model_dir=str(self.CKPTS_ROOT),
            seed=seed,
            resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência
            batch_size=sp_size,
            model="seedvr2_ema_3b_fp16.safetensors",
            preserve_vram=True,
            debug=True,
            cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
            skip_first_frames=0,
            load_cap=0
        )

        try:
            if progress:
                progress(0.1, desc="Iniciando a lógica de inferência...")
            
            # Chama a função importada do script original
            result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
            
            if progress:
                progress(0.9, desc="Salvando o vídeo resultante...")

            final_fps = fps if fps and fps > 0 else original_fps
            save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
            
            print(f"✅ Vídeo salvo com sucesso em: {output_filepath}")
            
            return str(output_filepath)
            
        except Exception as e:
            print(f"❌ Erro durante a execução direta da inferência: {e}")
            import traceback
            traceback.print_exc()
            raise # Propaga o erro para a UI do Gradio, que o exibirá.