Test2 / api /seedvr_server.py
EuuIia's picture
Update api/seedvr_server.py
9968866 verified
raw
history blame
6.47 kB
# api/seedvr_server.py
import os
import sys
import shutil
import mimetypes
import time
import subprocess # Necessário para clonar o repositório na configuração inicial
from pathlib import Path
from typing import Optional, Callable
from types import SimpleNamespace
from huggingface_hub import hf_hub_download
# Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
# Isso é crucial para que a importação do 'inference_cli' funcione.
SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
if str(SEEDVR_REPO_PATH) not in sys.path:
# Insere no início da lista para garantir prioridade de importação.
sys.path.insert(0, str(SEEDVR_REPO_PATH))
# Tenta importar as funções necessárias APÓS a modificação do path.
# Se falhar, a aplicação não pode continuar.
try:
from inference_cli import run_inference_logic, save_frames_to_video
except ImportError as e:
print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
raise e
class SeedVRServer:
def __init__(self, **kwargs):
"""
Inicializa o servidor, define os caminhos e prepara o ambiente.
"""
self.SEEDVR_ROOT = SEEDVR_REPO_PATH
self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
p.mkdir(parents=True, exist_ok=True)
self.setup_dependencies()
print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")
def setup_dependencies(self):
""" Garante que o repositório e os modelos estão presentes. """
self._ensure_repo()
self._ensure_model()
def _ensure_repo(self) -> None:
""" Clona o repositório do SeedVR se ele não existir. """
if not (self.SEEDVR_ROOT / ".git").exists():
print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
# Usamos subprocess.run aqui porque é uma tarefa de inicialização única.
subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
else:
print("[SeedVRServer] Repositório SeedVR já existe.")
def _ensure_model(self) -> None:
""" Baixa os checkpoints do Hugging Face se não existirem localmente. """
print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
model_files = {
"seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
"ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
}
for filename, repo_id in model_files.items():
if not (self.CKPTS_ROOT / filename).exists():
print(f"Baixando {filename} de {repo_id}...")
hf_hub_download(
repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
)
print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
def run_inference_direct(
self,
file_path: str, *,
seed: int, res_h: int, res_w: int, sp_size: int,
fps: Optional[float] = None, progress: Optional[Callable] = None
) -> str:
"""
Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
"""
# Cria um diretório de saída único para salvar o resultado.
out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
out_dir.mkdir(parents=True, exist_ok=True)
output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
# Simula o objeto 'args' que a função de lógica do inference_cli espera.
# Usamos SimpleNamespace para criar um objeto simples com atributos.
args = SimpleNamespace(
video_path=file_path,
output=str(output_filepath),
model_dir=str(self.CKPTS_ROOT),
seed=seed,
resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência.
batch_size=sp_size,
model="seedvr2_ema_7b_sharp_fp16.safetensors",
preserve_vram=True,
debug=True, # Mantém o debug ativo para logs detalhados.
cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
skip_first_frames=0,
load_cap=0,
output_format='video' # Garante que sempre gere vídeo
)
try:
# Informa a UI que o processo começou.
if progress:
progress(0.01, "Initializing...")
# Chama a função importada do script original, passando o callback de progresso.
# Este callback será chamado de dentro da lógica de multi-processamento.
result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
# Informa a UI que a inferência terminou e o salvamento vai começar.
if progress:
progress(0.95, "Saving the final video...")
# Define o FPS final: usa o valor da UI ou o original do vídeo de entrada.
final_fps = fps if fps and fps > 0 else original_fps
save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
print(f"✅ Video saved successfully to: {output_filepath}")
# Retorna o caminho do arquivo gerado para a UI.
return str(output_filepath)
except Exception as e:
print(f"❌ Error during direct inference execution: {e}")
import traceback
traceback.print_exc()
# Propaga o erro para a UI do Gradio, que o exibirá de forma amigável.
raise