File size: 6,470 Bytes
98e3867
7dd33fe
a5720bf
 
98e3867
9ad5b00
98e3867
9ad5b00
a5720bf
98e3867
 
7dd33fe
98e3867
8e5d88b
9ad5b00
 
98e3867
 
9ad5b00
98e3867
8e5d88b
9ad5b00
 
98e3867
9ad5b00
98e3867
9ad5b00
 
98e3867
8e5d88b
98e3867
 
 
 
 
 
 
 
 
 
 
c2b6094
7dd33fe
9ad5b00
98e3867
 
8e5d88b
98e3867
9ad5b00
98e3867
 
 
 
 
 
 
 
 
 
9ad5b00
98e3867
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ad5b00
 
98e3867
 
 
 
 
7dd33fe
9ad5b00
7dd33fe
9ad5b00
98e3867
 
 
 
9ad5b00
 
 
 
 
 
 
 
 
a3faf27
9968866
9ad5b00
 
 
 
 
 
 
98e3867
9ad5b00
 
 
7dd33fe
9ad5b00
 
 
2991b2f
9ad5b00
 
 
 
 
98e3867
9ad5b00
98e3867
 
9ad5b00
 
98e3867
 
 
9ad5b00
98e3867
 
9ad5b00
5841247
bd9f380
9ad5b00
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# api/seedvr_server.py

import os
import sys
import shutil
import mimetypes
import time
import subprocess # Necessário para clonar o repositório na configuração inicial
from pathlib import Path
from typing import Optional, Callable
from types import SimpleNamespace

from huggingface_hub import hf_hub_download

# Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
# Isso é crucial para que a importação do 'inference_cli' funcione.
SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
if str(SEEDVR_REPO_PATH) not in sys.path:
    # Insere no início da lista para garantir prioridade de importação.
    sys.path.insert(0, str(SEEDVR_REPO_PATH))

# Tenta importar as funções necessárias APÓS a modificação do path.
# Se falhar, a aplicação não pode continuar.
try:
    from inference_cli import run_inference_logic, save_frames_to_video
except ImportError as e:
    print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
    print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
    raise e

class SeedVRServer:
    def __init__(self, **kwargs):
        """
        Inicializa o servidor, define os caminhos e prepara o ambiente.
        """
        self.SEEDVR_ROOT = SEEDVR_REPO_PATH
        self.CKPTS_ROOT = Path("/data/seedvr_models_fp16") 
        self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
        self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
        self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
        
        print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
        for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
            p.mkdir(parents=True, exist_ok=True)
        
        self.setup_dependencies()
        print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")

    def setup_dependencies(self):
        """ Garante que o repositório e os modelos estão presentes. """
        self._ensure_repo()
        self._ensure_model()

    def _ensure_repo(self) -> None:
        """ Clona o repositório do SeedVR se ele não existir. """
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
            # Usamos subprocess.run aqui porque é uma tarefa de inicialização única.
            subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print("[SeedVRServer] Repositório SeedVR já existe.")

    def _ensure_model(self) -> None:
        """ Baixa os checkpoints do Hugging Face se não existirem localmente. """
        print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
        model_files = {
            "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
        }
        for filename, repo_id in model_files.items():
            if not (self.CKPTS_ROOT / filename).exists():
                print(f"Baixando {filename} de {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
                    cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
                )
        print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
        
    def run_inference_direct(
        self,
        file_path: str, *,
        seed: int, res_h: int, res_w: int, sp_size: int,
        fps: Optional[float] = None, progress: Optional[Callable] = None
    ) -> str:
        """
        Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
        """
        # Cria um diretório de saída único para salvar o resultado.
        out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
        out_dir.mkdir(parents=True, exist_ok=True)
        output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"

        # Simula o objeto 'args' que a função de lógica do inference_cli espera.
        # Usamos SimpleNamespace para criar um objeto simples com atributos.
        args = SimpleNamespace(
            video_path=file_path,
            output=str(output_filepath),
            model_dir=str(self.CKPTS_ROOT),
            seed=seed,
            resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência.
            batch_size=sp_size,
            model="seedvr2_ema_7b_sharp_fp16.safetensors",
            preserve_vram=True,
            debug=True, # Mantém o debug ativo para logs detalhados.
            cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
            skip_first_frames=0,
            load_cap=0,
            output_format='video' # Garante que sempre gere vídeo
        )

        try:
            # Informa a UI que o processo começou.
            if progress:
                progress(0.01, "Initializing...")
            
            # Chama a função importada do script original, passando o callback de progresso.
            # Este callback será chamado de dentro da lógica de multi-processamento.
            result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
            
            # Informa a UI que a inferência terminou e o salvamento vai começar.
            if progress:
                progress(0.95, "Saving the final video...")

            # Define o FPS final: usa o valor da UI ou o original do vídeo de entrada.
            final_fps = fps if fps and fps > 0 else original_fps
            save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
            
            print(f"✅ Video saved successfully to: {output_filepath}")
            
            # Retorna o caminho do arquivo gerado para a UI.
            return str(output_filepath)
            
        except Exception as e:
            print(f"❌ Error during direct inference execution: {e}")
            import traceback
            traceback.print_exc()
            # Propaga o erro para a UI do Gradio, que o exibirá de forma amigável.
            raise