File size: 6,470 Bytes
98e3867 7dd33fe a5720bf 98e3867 9ad5b00 98e3867 9ad5b00 a5720bf 98e3867 7dd33fe 98e3867 8e5d88b 9ad5b00 98e3867 9ad5b00 98e3867 8e5d88b 9ad5b00 98e3867 9ad5b00 98e3867 9ad5b00 98e3867 8e5d88b 98e3867 c2b6094 7dd33fe 9ad5b00 98e3867 8e5d88b 98e3867 9ad5b00 98e3867 9ad5b00 98e3867 9ad5b00 98e3867 7dd33fe 9ad5b00 7dd33fe 9ad5b00 98e3867 9ad5b00 a3faf27 9968866 9ad5b00 98e3867 9ad5b00 7dd33fe 9ad5b00 2991b2f 9ad5b00 98e3867 9ad5b00 98e3867 9ad5b00 98e3867 9ad5b00 98e3867 9ad5b00 5841247 bd9f380 9ad5b00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# api/seedvr_server.py
import os
import sys
import shutil
import mimetypes
import time
import subprocess # Necessário para clonar o repositório na configuração inicial
from pathlib import Path
from typing import Optional, Callable
from types import SimpleNamespace
from huggingface_hub import hf_hub_download
# Adiciona dinamicamente o caminho do repositório clonado ao sys.path.
# Isso é crucial para que a importação do 'inference_cli' funcione.
SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
if str(SEEDVR_REPO_PATH) not in sys.path:
# Insere no início da lista para garantir prioridade de importação.
sys.path.insert(0, str(SEEDVR_REPO_PATH))
# Tenta importar as funções necessárias APÓS a modificação do path.
# Se falhar, a aplicação não pode continuar.
try:
from inference_cli import run_inference_logic, save_frames_to_video
except ImportError as e:
print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'.")
print(f"Verifique se o repositório em '{SEEDVR_REPO_PATH}' está correto e completo.")
raise e
class SeedVRServer:
def __init__(self, **kwargs):
"""
Inicializa o servidor, define os caminhos e prepara o ambiente.
"""
self.SEEDVR_ROOT = SEEDVR_REPO_PATH
self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
p.mkdir(parents=True, exist_ok=True)
self.setup_dependencies()
print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")
def setup_dependencies(self):
""" Garante que o repositório e os modelos estão presentes. """
self._ensure_repo()
self._ensure_model()
def _ensure_repo(self) -> None:
""" Clona o repositório do SeedVR se ele não existir. """
if not (self.SEEDVR_ROOT / ".git").exists():
print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
# Usamos subprocess.run aqui porque é uma tarefa de inicialização única.
subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
else:
print("[SeedVRServer] Repositório SeedVR já existe.")
def _ensure_model(self) -> None:
""" Baixa os checkpoints do Hugging Face se não existirem localmente. """
print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
model_files = {
"seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
"ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
}
for filename, repo_id in model_files.items():
if not (self.CKPTS_ROOT / filename).exists():
print(f"Baixando {filename} de {repo_id}...")
hf_hub_download(
repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
)
print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
def run_inference_direct(
self,
file_path: str, *,
seed: int, res_h: int, res_w: int, sp_size: int,
fps: Optional[float] = None, progress: Optional[Callable] = None
) -> str:
"""
Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
"""
# Cria um diretório de saída único para salvar o resultado.
out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}_{Path(file_path).stem}"
out_dir.mkdir(parents=True, exist_ok=True)
output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
# Simula o objeto 'args' que a função de lógica do inference_cli espera.
# Usamos SimpleNamespace para criar um objeto simples com atributos.
args = SimpleNamespace(
video_path=file_path,
output=str(output_filepath),
model_dir=str(self.CKPTS_ROOT),
seed=seed,
resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência.
batch_size=sp_size,
model="seedvr2_ema_7b_sharp_fp16.safetensors",
preserve_vram=True,
debug=True, # Mantém o debug ativo para logs detalhados.
cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
skip_first_frames=0,
load_cap=0,
output_format='video' # Garante que sempre gere vídeo
)
try:
# Informa a UI que o processo começou.
if progress:
progress(0.01, "Initializing...")
# Chama a função importada do script original, passando o callback de progresso.
# Este callback será chamado de dentro da lógica de multi-processamento.
result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
# Informa a UI que a inferência terminou e o salvamento vai começar.
if progress:
progress(0.95, "Saving the final video...")
# Define o FPS final: usa o valor da UI ou o original do vídeo de entrada.
final_fps = fps if fps and fps > 0 else original_fps
save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
print(f"✅ Video saved successfully to: {output_filepath}")
# Retorna o caminho do arquivo gerado para a UI.
return str(output_filepath)
except Exception as e:
print(f"❌ Error during direct inference execution: {e}")
import traceback
traceback.print_exc()
# Propaga o erro para a UI do Gradio, que o exibirá de forma amigável.
raise
|