File size: 5,512 Bytes
7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf c96bb7c 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf 7dd33fe a5720bf c96bb7c 7dd33fe a5720bf c96bb7c 7dd33fe c96bb7c a5720bf 7dd33fe c96bb7c 7dd33fe c96bb7c a5720bf 7dd33fe a5720bf 7dd33fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
# api/seedvr_server.py
import os
import sys
import shutil
import mimetypes
import time
from pathlib import Path
from typing import Optional, Tuple
from types import SimpleNamespace
from huggingface_hub import hf_hub_download
# Adiciona dinamicamente o caminho do repositório clonado ao sys.path
# Isso é crucial para que a importação do 'inference_cli' funcione.
SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
if str(SEEDVR_REPO_PATH) not in sys.path:
sys.path.insert(0, str(SEEDVR_REPO_PATH))
# Tenta importar as funções necessárias APÓS a modificação do path.
try:
from inference_cli import run_inference_logic, save_frames_to_video
except ImportError as e:
print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'. Verifique se o repositório em {SEEDVR_REPO_PATH} está correto.")
raise e
class SeedVRServer:
def __init__(self, **kwargs):
"""
Inicializa o servidor, define os caminhos e prepara o ambiente.
"""
self.SEEDVR_ROOT = SEEDVR_REPO_PATH
self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
p.mkdir(parents=True, exist_ok=True)
self.setup_dependencies()
print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")
def setup_dependencies(self):
""" Garante que o repositório e os modelos estão presentes. """
self._ensure_repo()
self._ensure_model()
def _ensure_repo(self) -> None:
""" Clona o repositório do SeedVR se ele não existir. """
if not (self.SEEDVR_ROOT / ".git").exists():
print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
else:
print("[SeedVRServer] Repositório SeedVR já existe.")
def _ensure_model(self) -> None:
""" Baixa os checkpoints do Hugging Face se não existirem localmente. """
print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
model_files = {
"seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
"ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
"pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
"neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
}
for filename, repo_id in model_files.items():
if not (self.CKPTS_ROOT / filename).exists():
print(f"Baixando {filename} de {repo_id}...")
hf_hub_download(
repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
)
print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
def run_inference_direct(
self,
file_path: str, *,
seed: int, res_h: int, res_w: int, sp_size: int,
fps: Optional[float] = None, progress=None
) -> str:
"""
Executa a inferência diretamente no mesmo processo e retorna o caminho do arquivo de saída.
"""
out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}"
out_dir.mkdir(parents=True, exist_ok=True)
output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"
# Simula o objeto 'args' que a função de lógica espera
args = SimpleNamespace(
video_path=file_path,
output=str(output_filepath),
model_dir=str(self.CKPTS_ROOT),
seed=seed,
resolution=res_h, # O script do SeedVR usa a altura (lado menor) como referência
batch_size=sp_size,
model="seedvr2_ema_3b_fp16.safetensors",
preserve_vram=True,
debug=True,
cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
skip_first_frames=0,
load_cap=0
)
try:
if progress:
progress(0.1, desc="Iniciando a lógica de inferência...")
# Chama a função importada do script original
result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)
if progress:
progress(0.9, desc="Salvando o vídeo resultante...")
final_fps = fps if fps and fps > 0 else original_fps
save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)
print(f"✅ Vídeo salvo com sucesso em: {output_filepath}")
return str(output_filepath)
except Exception as e:
print(f"❌ Erro durante a execução direta da inferência: {e}")
import traceback
traceback.print_exc()
raise # Propaga o erro para a UI do Gradio, que o exibirá. |