EuuIia committed on
Commit
7dd33fe
·
verified ·
1 Parent(s): 3e978e1

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +86 -105
api/seedvr_server.py CHANGED
@@ -1,147 +1,128 @@
 
 
1
  import os
2
- import shutil
3
- import subprocess
4
  import sys
5
- import time
6
  import mimetypes
 
7
  from pathlib import Path
8
  from typing import Optional, Tuple
 
 
9
  from huggingface_hub import hf_hub_download
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
class SeedVRServer:
    """Drive SeedVR2 FP16 video/image upscaling by launching the repository's
    ``inference_cli.py`` under ``torchrun`` as a subprocess.

    On construction it prepares all working directories, clones the SeedVR
    repository if needed and downloads the FP16 checkpoints from Hugging Face.
    """

    def __init__(self, **kwargs):
        self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
        self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
        self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
        self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        # Default to 8 GPUs, but never more than what is visible in the container/host.
        self.NUM_GPUS_TOTAL = min(int(os.getenv("NUM_GPUS", "8")),
                                  int(os.getenv("MAX_VISIBLE_GPUS", "8")))
        print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")

        for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
            p.mkdir(parents=True, exist_ok=True)

        self.setup_dependencies()
        print("✅ SeedVRServer (FP16) pronto.")

    def setup_dependencies(self):
        """Ensure both the SeedVR repository and the model checkpoints are present."""
        self._ensure_repo()
        self._ensure_model()

    def _ensure_repo(self) -> None:
        """Shallow-clone the SeedVR repository if it is not already checked out."""
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
            subprocess.run(["git", "clone", "--depth", "1", os.getenv("SEEDVR_GIT_URL",
                "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler"), str(self.SEEDVR_ROOT)], check=True)
        else:
            print("[SeedVRServer] Repositório SeedVR já existe.")

    def _ensure_model(self) -> None:
        """Download any missing FP16 checkpoints from Hugging Face into CKPTS_ROOT."""
        print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
        model_files = {
            "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
            "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B",
        }
        for filename, repo_id in model_files.items():
            if not (self.CKPTS_ROOT / filename).exists():
                # FIX: report the actual file being fetched; the original printed
                # the literal text "(unknown)" instead of the filename.
                print(f"Baixando {filename} de {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    local_dir=str(self.CKPTS_ROOT),
                    cache_dir=str(self.HF_HOME_CACHE),
                    token=os.getenv("HF_TOKEN"),
                )
        print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")

    def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
        """Create unique per-job input/output directories and copy the input file in.

        Returns (job_input_dir, out_dir). The random suffix avoids collisions
        between jobs started within the same second.
        """
        ts = f"{int(time.time())}_{os.urandom(4).hex()}"
        job_input_dir = self.INPUT_ROOT / f"job_{ts}"
        out_dir = self.OUTPUT_ROOT / f"run_{ts}"
        job_input_dir.mkdir(parents=True, exist_ok=True)
        out_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(input_file, job_input_dir / Path(input_file).name)
        return job_input_dir, out_dir

    def _visible_devices_for(self, nproc: int) -> str:
        # Map logical devices 0..nproc-1 into the container-visible device space.
        return ",".join(str(i) for i in range(nproc))

    def run_inference(
        self,
        file_path: str,
        *,
        seed: int,
        res_h: int,
        res_w: int,
        sp_size: int,
        fps: Optional[float] = None,
    ) -> Tuple[Optional[str], Optional[str], Path]:
        """Run upscaling on `file_path` via torchrun.

        Returns (image_path_or_None, video_path_or_None, out_dir). On failure
        returns (None, None, out_dir) after logging the error.

        NOTE(review): `res_w` and `fps` are accepted but never forwarded to the
        CLI — only `res_h` is passed as --resolution. Confirm against the CLI's
        contract before relying on them.
        """
        script = self.SEEDVR_ROOT / "inference_cli.py"
        # NOTE(review): job_input_dir receives a copy of the input, but the
        # command below still reads the ORIGINAL file_path — the copy appears unused.
        job_input_dir, out_dir = self._prepare_job(file_path)

        media_type, _ = mimetypes.guess_type(file_path)
        is_image = bool(media_type and media_type.startswith("image"))

        # Policy: 1 GPU for an image, NUM_GPUS_TOTAL (up to 8) for a video.
        effective_nproc = 1 if is_image else self.NUM_GPUS_TOTAL
        effective_sp_size = 1 if is_image else sp_size

        output_filename = f"result_{Path(file_path).stem}.mp4"
        output_filepath = out_dir / output_filename

        cmd = [
            "torchrun",
            "--standalone",
            "--nnodes=1",
            f"--nproc-per-node={effective_nproc}",
            str(script),
            "--video_path", str(file_path),
            "--output", str(output_filepath),
            "--model_dir", str(self.CKPTS_ROOT),
            "--seed", str(seed),
            "--resolution", str(res_h),
            "--batch_size", str(effective_sp_size),
            "--model", "seedvr2_ema_3b_fp16.safetensors",
            "--preserve_vram",
            "--debug",
            "--output_format", "video",
        ]
        # Removed: --cuda_device ... (torchrun + LOCAL_RANK handles device binding).

        env = os.environ.copy()
        # Align the logical device space with nproc.
        env["CUDA_VISIBLE_DEVICES"] = self._visible_devices_for(effective_nproc)
        # Useful debug hints (optional):
        # env["NCCL_DEBUG"] = "WARN"
        # env["CUDA_LAUNCH_BLOCKING"] = "1"

        print("[SeedVRServer] Comando:", " ".join(cmd))
        print("[SeedVRServer] CUDA_VISIBLE_DEVICES:", env.get("CUDA_VISIBLE_DEVICES", ""))

        try:
            subprocess.run(
                cmd,
                cwd=str(self.SEEDVR_ROOT),
                check=True,
                env=env,
                stdout=sys.stdout,
                stderr=sys.stderr,
            )

            if is_image:
                # If output_format=png in the CLI this could be a directory; with
                # "video" we keep mp4, but preserve compatibility in case the CLI changes.
                image_dir = output_filepath if output_filepath.suffix == "" else output_filepath.with_suffix("")
                return str(image_dir), None, out_dir
            else:
                return None, str(output_filepath), out_dir

        except Exception as e:
            # Boundary handler: log and signal failure to the UI without raising.
            print(f"[UI ERROR] A inferência falhou: {e}")
            return None, None, out_dir
 
 
 
# api/seedvr_server.py

import os
import sys
import shutil
import subprocess  # FIX: required by SeedVRServer._ensure_repo (git clone); was dropped in this commit
import mimetypes
import time
from pathlib import Path
from typing import Optional, Tuple
from types import SimpleNamespace

from huggingface_hub import hf_hub_download

# Dynamically add the cloned repository's path to sys.path.
# This is crucial so that the import of 'inference_cli' works.
SEEDVR_REPO_PATH = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
if str(SEEDVR_REPO_PATH) not in sys.path:
    sys.path.insert(0, str(SEEDVR_REPO_PATH))

# Try to import the required functions AFTER the path modification.
try:
    from inference_cli import run_inference_logic, save_frames_to_video
except ImportError as e:
    print(f"ERRO FATAL: Não foi possível importar de 'inference_cli.py'. Verifique se o repositório em {SEEDVR_REPO_PATH} está correto.")
    raise e
27
class SeedVRServer:
    """Run SeedVR2 FP16 upscaling by calling the repository's inference logic
    directly in-process (via ``run_inference_logic``) instead of spawning torchrun.
    """

    def __init__(self, **kwargs):
        """Initialize the server, define the paths and prepare the environment."""
        self.SEEDVR_ROOT = SEEDVR_REPO_PATH
        self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
        self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
        self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
        self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))

        print("🚀 SeedVRServer (Modo de Chamada Direta) inicializando...")
        for p in [self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
            p.mkdir(parents=True, exist_ok=True)

        self.setup_dependencies()
        print("✅ SeedVRServer (Modo de Chamada Direta) pronto.")

    def setup_dependencies(self):
        """Ensure that the repository and the models are present."""
        self._ensure_repo()
        self._ensure_model()

    def _ensure_repo(self) -> None:
        """Clone the SeedVR repository if it does not exist."""
        if not (self.SEEDVR_ROOT / ".git").exists():
            # FIX: this commit removed the module-level `import subprocess`, which
            # would make this call raise NameError; import locally so the method
            # is self-contained.
            import subprocess
            print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
            subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print("[SeedVRServer] Repositório SeedVR já existe.")

    def _ensure_model(self) -> None:
        """Download the checkpoints from Hugging Face if they do not exist locally."""
        print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
        model_files = {
            "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
            "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B",
        }
        for filename, repo_id in model_files.items():
            if not (self.CKPTS_ROOT / filename).exists():
                # FIX: report the actual file being fetched; the original printed
                # the literal text "(unknown)" instead of the filename.
                print(f"Baixando {filename} de {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT),
                    cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN")
                )
        print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")

    def run_inference_direct(
        self,
        file_path: str, *,
        seed: int, res_h: int, res_w: int, sp_size: int,
        fps: Optional[float] = None, progress=None
    ) -> str:
        """Execute inference directly in this process and return the output file path.

        Raises whatever `run_inference_logic`/`save_frames_to_video` raise, after
        logging a traceback (the Gradio UI is expected to display the error).

        NOTE(review): `res_w` is accepted but never used — only `res_h` is passed
        as `resolution`. Confirm the CLI's resolution semantics before relying on it.
        NOTE(review): the output directory is keyed on a 1-second timestamp only;
        two jobs started in the same second would share `out_dir` (the previous
        subprocess version added a random suffix).
        """
        out_dir = self.OUTPUT_ROOT / f"run_{int(time.time())}"
        out_dir.mkdir(parents=True, exist_ok=True)
        output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"

        # Mimic the 'args' namespace that the CLI's logic function expects.
        args = SimpleNamespace(
            video_path=file_path,
            output=str(output_filepath),
            model_dir=str(self.CKPTS_ROOT),
            seed=seed,
            resolution=res_h,  # the SeedVR script uses the height (short side) as reference
            batch_size=sp_size,
            model="seedvr2_ema_3b_fp16.safetensors",
            preserve_vram=True,
            debug=True,
            cuda_device=",".join(map(str, range(self.NUM_GPUS_TOTAL))),
            skip_first_frames=0,
            load_cap=0
        )

        try:
            if progress:
                progress(0.1, desc="Iniciando a lógica de inferência...")

            # Call the function imported from the original CLI script.
            result_tensor, original_fps, _, _ = run_inference_logic(args, progress_callback=progress)

            if progress:
                progress(0.9, desc="Salvando o vídeo resultante...")

            # Prefer an explicit caller-supplied fps; otherwise keep the source fps.
            final_fps = fps if fps and fps > 0 else original_fps
            save_frames_to_video(result_tensor, str(output_filepath), final_fps, args.debug)

            print(f"✅ Vídeo salvo com sucesso em: {output_filepath}")

            return str(output_filepath)

        except Exception as e:
            print(f" Erro durante a execução direta da inferência: {e}")
            import traceback
            traceback.print_exc()
            raise  # Propagate the error to the Gradio UI, which will display it.