Test2 / api /seedvr_server.py
EuuIia's picture
Update api/seedvr_server.py
c96bb7c verified
raw
history blame
6.05 kB
import os
import shutil
import subprocess
import sys
import time
import mimetypes
from pathlib import Path
from typing import Optional, Tuple
from huggingface_hub import hf_hub_download
class SeedVRServer:
    """Prepare the SeedVR2 (FP16) environment and run upscaling jobs via ``torchrun``.

    Construction creates the working directories, clones the SeedVR
    repository when it is missing, and downloads any missing checkpoints.
    """

    def __init__(self, **kwargs):
        """Read configuration from the environment and set up dependencies.

        ``kwargs`` is accepted for call-site compatibility but is not used.
        Roots come from ``SEEDVR_ROOT``, ``OUTPUT_ROOT``, ``INPUT_ROOT`` and
        ``HF_HOME``; the checkpoint directory is fixed.
        """
        self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
        self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
        self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
        self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        # Defaults to 8, never above MAX_VISIBLE_GPUS, never below 1.
        self.NUM_GPUS_TOTAL = self._gpu_budget()

        print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")
        # Only SEEDVR_ROOT's parent is created: `git clone` creates SEEDVR_ROOT itself.
        for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
            p.mkdir(parents=True, exist_ok=True)
        self.setup_dependencies()
        print("✅ SeedVRServer (FP16) pronto.")

    @staticmethod
    def _gpu_budget() -> int:
        """Return min(NUM_GPUS, MAX_VISIBLE_GPUS) from the environment, clamped to at least 1.

        The clamp keeps ``--nproc-per-node`` valid even when an env var is 0 or negative.
        """
        requested = int(os.getenv("NUM_GPUS", "8"))
        visible_cap = int(os.getenv("MAX_VISIBLE_GPUS", "8"))
        return max(1, min(requested, visible_cap))

    def setup_dependencies(self):
        """Ensure the SeedVR repository and the FP16 checkpoints are present."""
        self._ensure_repo()
        self._ensure_model()

    def _ensure_repo(self) -> None:
        """Shallow-clone the repository (``SEEDVR_GIT_URL`` or the default) unless ``SEEDVR_ROOT/.git`` exists."""
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
            subprocess.run(["git", "clone", "--depth", "1", os.getenv("SEEDVR_GIT_URL",
                "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler"), str(self.SEEDVR_ROOT)], check=True)
        else:
            print("[SeedVRServer] Repositório SeedVR já existe.")

    def _ensure_model(self) -> None:
        """Download each missing checkpoint file from its Hugging Face repo into ``CKPTS_ROOT``.

        Files already present in ``CKPTS_ROOT`` are skipped. ``HF_TOKEN``, if set,
        is passed for authenticated downloads.
        """
        print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
        model_files = {
            "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
            "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B",
        }
        for filename, repo_id in model_files.items():
            if not (self.CKPTS_ROOT / filename).exists():
                print(f"Baixando {filename} de {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    local_dir=str(self.CKPTS_ROOT),
                    cache_dir=str(self.HF_HOME_CACHE),
                    token=os.getenv("HF_TOKEN"),
                )
        print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")

    def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
        """Create unique per-job input/output dirs and copy ``input_file`` into the input dir.

        Returns:
            ``(job_input_dir, out_dir)``. The copy keeps the original file name.
        """
        ts = f"{int(time.time())}_{os.urandom(4).hex()}"
        job_input_dir = self.INPUT_ROOT / f"job_{ts}"
        out_dir = self.OUTPUT_ROOT / f"run_{ts}"
        job_input_dir.mkdir(parents=True, exist_ok=True)
        out_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(input_file, job_input_dir / Path(input_file).name)
        return job_input_dir, out_dir

    @staticmethod
    def _inherited_devices() -> list:
        """Return the non-empty entries of the current CUDA_VISIBLE_DEVICES (empty list if unset/blank)."""
        raw = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        return [tok.strip() for tok in raw.split(",") if tok.strip()]

    def _visible_devices_for(self, nproc: int) -> str:
        """Return a CUDA_VISIBLE_DEVICES value exposing ``nproc`` GPUs to the child.

        If the parent already restricts CUDA_VISIBLE_DEVICES, the first ``nproc``
        of those entries are reused so the child never sees GPUs the host hid.
        Otherwise logical ids ``0..nproc-1`` are used.
        """
        inherited = self._inherited_devices()
        if inherited:
            return ",".join(inherited[:nproc])
        return ",".join(str(i) for i in range(nproc))

    def _build_command(
        self,
        script: Path,
        video_path: Path,
        output_filepath: Path,
        *,
        seed: int,
        res_h: int,
        nproc: int,
        batch_size: int,
    ) -> list:
        """Assemble the single-node ``torchrun`` argv for the SeedVR inference CLI."""
        # No --cuda_device: torchrun sets LOCAL_RANK per process for device binding.
        return [
            "torchrun",
            "--standalone",
            "--nnodes=1",
            f"--nproc-per-node={nproc}",
            str(script),
            "--video_path", str(video_path),
            "--output", str(output_filepath),
            "--model_dir", str(self.CKPTS_ROOT),
            "--seed", str(seed),
            "--resolution", str(res_h),
            "--batch_size", str(batch_size),
            "--model", "seedvr2_ema_3b_fp16.safetensors",
            "--preserve_vram",
            "--debug",
            "--output_format", "video",
        ]

    def run_inference(
        self,
        file_path: str,
        *,
        seed: int,
        res_h: int,
        res_w: int,
        sp_size: int,
        fps: Optional[float] = None,
    ) -> Tuple[Optional[str], Optional[str], Path]:
        """Run one upscaling job on ``file_path`` and report where the result was written.

        Images run on 1 GPU with batch size 1. Videos run on ``NUM_GPUS_TOTAL``
        GPUs (capped by any inherited CUDA_VISIBLE_DEVICES) with
        ``--batch_size sp_size``. Only ``res_h`` is passed as ``--resolution``;
        ``res_w`` and ``fps`` are accepted but currently unused.

        Returns:
            ``(image_path, video_path, out_dir)``. On failure, both paths are
            ``None``; the error is printed, not raised.
        """
        # The child runs with cwd=SEEDVR_ROOT, so every path handed to it is
        # made absolute; relative paths would resolve against the wrong directory.
        script = (self.SEEDVR_ROOT / "inference_cli.py").resolve()
        job_input_dir, out_dir = self._prepare_job(file_path)
        # Use the per-job copy so the job does not depend on the caller's file staying put.
        job_input = (job_input_dir / Path(file_path).name).resolve()

        media_type, _ = mimetypes.guess_type(file_path)
        is_image = bool(media_type and media_type.startswith("image"))

        # Policy: 1 GPU for images, NUM_GPUS_TOTAL for videos; never more ranks
        # than GPUs the host has made visible.
        effective_nproc = 1 if is_image else self.NUM_GPUS_TOTAL
        inherited = self._inherited_devices()
        if inherited:
            effective_nproc = min(effective_nproc, len(inherited))
        effective_sp_size = 1 if is_image else sp_size

        output_filepath = (out_dir / f"result_{Path(file_path).stem}.mp4").resolve()
        cmd = self._build_command(
            script,
            job_input,
            output_filepath,
            seed=seed,
            res_h=res_h,
            nproc=effective_nproc,
            batch_size=effective_sp_size,
        )

        env = os.environ.copy()
        # Align the child's logical device space with the number of ranks.
        env["CUDA_VISIBLE_DEVICES"] = self._visible_devices_for(effective_nproc)
        # Optional debug knobs:
        # env["NCCL_DEBUG"] = "WARN"
        # env["CUDA_LAUNCH_BLOCKING"] = "1"

        print("[SeedVRServer] Comando:", " ".join(cmd))
        print("[SeedVRServer] CUDA_VISIBLE_DEVICES:", env.get("CUDA_VISIBLE_DEVICES", ""))

        try:
            subprocess.run(
                cmd,
                cwd=str(self.SEEDVR_ROOT),
                check=True,
                env=env,
                stdout=sys.stdout,
                stderr=sys.stderr,
            )
        except Exception as e:
            # UI boundary: report the failure instead of propagating it.
            print(f"[UI ERROR] A inferência falhou: {e}")
            return None, None, out_dir

        if is_image:
            # NOTE(review): with --output_format video the CLI writes the .mp4, so
            # this suffix-less path is presumably only meaningful if the CLI is
            # switched to an image (directory) output format — confirm with callers.
            image_dir = output_filepath if output_filepath.suffix == "" else output_filepath.with_suffix("")
            return str(image_dir), None, out_dir
        return None, str(output_filepath), out_dir