EuuIia committed on
Commit
c96bb7c
·
verified ·
1 Parent(s): b91aa33

Update api/seedvr_server.py

Browse files
Files changed (1) hide show
  1. api/seedvr_server.py +71 -42
api/seedvr_server.py CHANGED
@@ -5,55 +5,59 @@ import sys
5
  import time
6
  import mimetypes
7
  from pathlib import Path
8
- from typing import List, Optional, Tuple
9
-
10
  from huggingface_hub import hf_hub_download
11
 
12
  class SeedVRServer:
13
  def __init__(self, **kwargs):
14
  self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
15
- # Apontamos para o nosso diretório de checkpoints customizado
16
- self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
17
  self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
18
  self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
19
  self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
20
- self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
21
- self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
22
-
23
  print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")
 
24
  for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
25
  p.mkdir(parents=True, exist_ok=True)
26
-
27
  self.setup_dependencies()
28
  print("✅ SeedVRServer (FP16) pronto.")
29
 
30
  def setup_dependencies(self):
31
  self._ensure_repo()
32
- # O monkey patch agora é feito pelo start_seedvr.sh, não mais aqui.
33
  self._ensure_model()
34
 
35
  def _ensure_repo(self) -> None:
36
  if not (self.SEEDVR_ROOT / ".git").exists():
37
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
38
- subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
 
39
  else:
40
  print("[SeedVRServer] Repositório SeedVR já existe.")
41
 
42
  def _ensure_model(self) -> None:
43
- """Baixa os arquivos de modelo FP16 otimizados e suas dependências."""
44
  print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
45
-
46
  model_files = {
47
- "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses", "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
48
- "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B", "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B"
 
 
49
  }
50
-
51
  for filename, repo_id in model_files.items():
52
  if not (self.CKPTS_ROOT / filename).exists():
53
  print(f"Baixando {filename} de {repo_id}...")
54
- hf_hub_download(repo_id=repo_id, filename=filename, local_dir=str(self.CKPTS_ROOT), cache_dir=str(self.HF_HOME_CACHE), token=os.getenv("HF_TOKEN"))
 
 
 
 
 
 
55
  print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
56
-
57
  def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
58
  ts = f"{int(time.time())}_{os.urandom(4).hex()}"
59
  job_input_dir = self.INPUT_ROOT / f"job_{ts}"
@@ -62,57 +66,82 @@ class SeedVRServer:
62
  out_dir.mkdir(parents=True, exist_ok=True)
63
  shutil.copy2(input_file, job_input_dir / Path(input_file).name)
64
  return job_input_dir, out_dir
65
-
66
def run_inference(self, file_path: str, *, seed: int, res_h: int, res_w: int, sp_size: int, fps: Optional[float] = None) -> Tuple[Optional[str], Optional[str], Path]:
    """Run the SeedVR ``inference_cli.py`` under torchrun for one media file.

    Args:
        file_path: Path of the image or video to process.
        seed: Forwarded as ``--seed``.
        res_h: Forwarded as ``--resolution``.
        res_w: Accepted for interface compatibility; not forwarded to the CLI.
        sp_size: Forwarded as ``--batch_size`` for videos (images always use 1).
        fps: Accepted for interface compatibility; not forwarded to the CLI.

    Returns:
        ``(image_dir, video_path, out_dir)``. On success exactly one of the
        first two is a string, chosen by the guessed media type. If the
        subprocess fails both are ``None``. ``out_dir`` is always the job's
        output directory.
    """
    script = self.SEEDVR_ROOT / "inference_cli.py"
    # _prepare_job also copies the input into a per-job folder, but the CLI
    # is still pointed at the original file_path.
    _job_input_dir, out_dir = self._prepare_job(file_path)

    media_type, _ = mimetypes.guess_type(file_path)
    is_image = bool(media_type and media_type.startswith("image"))

    effective_nproc = 1 if is_image else self.NUM_GPUS_TOTAL
    effective_sp_size = 1 if is_image else sp_size

    output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"

    # Follows the configured GPU count; equals the former hard-coded
    # "0,1,2,3" when NUM_GPUS is left at its default of 4.
    cuda_devices = ",".join(str(i) for i in range(max(1, self.NUM_GPUS_TOTAL)))

    cmd = [
        "torchrun", "--standalone", "--nnodes=1", f"--nproc-per-node={effective_nproc}", str(script),
        "--video_path", str(file_path),
        "--output", str(output_filepath),
        "--model_dir", str(self.CKPTS_ROOT),
        "--seed", str(seed),
        "--cuda_device", cuda_devices,
        "--resolution", str(res_h),
        "--batch_size", str(effective_sp_size),
        "--model", "seedvr2_ema_3b_fp16.safetensors",
        "--preserve_vram",
        "--debug",
        "--output_format", "video",
    ]

    print("[SeedVRServer] Comando:", " ".join(cmd))

    # Keep only the subprocess call inside the try so that a programming
    # error while building the result cannot be mistaken for a failed run.
    try:
        subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=os.environ.copy(), stdout=sys.stdout, stderr=sys.stderr)
    except Exception as e:
        print(f"[UI ERROR] A inferência falhou: {e}")
        return None, None, out_dir

    if is_image:
        # The output name always ends in ".mp4", so this is that path with
        # the suffix removed.
        # NOTE(review): the original comment expected the CLI to write PNGs
        # to a directory when the output format is png, but "video" is what
        # is passed above - confirm the image path actually exists.
        image_dir = output_filepath.with_suffix("")
        return str(image_dir), None, out_dir

    # The CLI is asked to write the video exactly at output_filepath.
    return None, str(output_filepath), out_dir
116
-
117
-
118
-
 
5
  import time
6
  import mimetypes
7
  from pathlib import Path
8
+ from typing import Optional, Tuple
 
9
  from huggingface_hub import hf_hub_download
10
 
11
  class SeedVRServer:
12
  def __init__(self, **kwargs):
13
  self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
14
+ self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
 
15
  self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
16
  self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
17
  self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
18
+ # Use 8 por padrão, mas nunca maior que o visível no container/host
19
+ self.NUM_GPUS_TOTAL = min(int(os.getenv("NUM_GPUS", "8")),
20
+ int(os.getenv("MAX_VISIBLE_GPUS", "8")))
21
  print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")
22
+
23
  for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
24
  p.mkdir(parents=True, exist_ok=True)
25
+
26
  self.setup_dependencies()
27
  print("✅ SeedVRServer (FP16) pronto.")
28
 
29
  def setup_dependencies(self):
30
  self._ensure_repo()
 
31
  self._ensure_model()
32
 
33
  def _ensure_repo(self) -> None:
34
  if not (self.SEEDVR_ROOT / ".git").exists():
35
  print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
36
+ subprocess.run(["git", "clone", "--depth", "1", os.getenv("SEEDVR_GIT_URL",
37
+ "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler"), str(self.SEEDVR_ROOT)], check=True)
38
  else:
39
  print("[SeedVRServer] Repositório SeedVR já existe.")
40
 
41
  def _ensure_model(self) -> None:
 
42
  print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
 
43
  model_files = {
44
+ "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
45
+ "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
46
+ "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
47
+ "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B",
48
  }
 
49
  for filename, repo_id in model_files.items():
50
  if not (self.CKPTS_ROOT / filename).exists():
51
  print(f"Baixando {filename} de {repo_id}...")
52
+ hf_hub_download(
53
+ repo_id=repo_id,
54
+ filename=filename,
55
+ local_dir=str(self.CKPTS_ROOT),
56
+ cache_dir=str(self.HF_HOME_CACHE),
57
+ token=os.getenv("HF_TOKEN"),
58
+ )
59
  print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")
60
+
61
  def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
62
  ts = f"{int(time.time())}_{os.urandom(4).hex()}"
63
  job_input_dir = self.INPUT_ROOT / f"job_{ts}"
 
66
  out_dir.mkdir(parents=True, exist_ok=True)
67
  shutil.copy2(input_file, job_input_dir / Path(input_file).name)
68
  return job_input_dir, out_dir
69
+
70
+ def _visible_devices_for(self, nproc: int) -> str:
71
+ # Mapeia 0..nproc-1 (lógico) para o espaço visível do container
72
+ return ",".join(str(i) for i in range(nproc))
73
+
74
def run_inference(
    self,
    file_path: str,
    *,
    seed: int,
    res_h: int,
    res_w: int,
    sp_size: int,
    fps: Optional[float] = None,
) -> Tuple[Optional[str], Optional[str], Path]:
    """Launch the SeedVR ``inference_cli.py`` through torchrun for one media file.

    Args:
        file_path: Path of the image or video to process.
        seed: Forwarded as ``--seed``.
        res_h: Forwarded as ``--resolution``.
        res_w: Not used by this method.
        sp_size: Forwarded as ``--batch_size`` for videos (images use 1).
        fps: Not used by this method.

    Returns:
        ``(image_dir, None, out_dir)`` for images, ``(None, video_path,
        out_dir)`` for anything else, and ``(None, None, out_dir)`` when the
        subprocess call raises.
    """
    cli_script = self.SEEDVR_ROOT / "inference_cli.py"
    # The per-job input copy is created by _prepare_job, but the CLI is
    # pointed at the original file_path below.
    _job_input_dir, out_dir = self._prepare_job(file_path)

    guessed_type = mimetypes.guess_type(file_path)[0]
    is_image = bool(guessed_type) and guessed_type.startswith("image")

    # Policy: one process for an image, NUM_GPUS_TOTAL processes for a video.
    if is_image:
        effective_nproc, effective_sp_size = 1, 1
    else:
        effective_nproc, effective_sp_size = self.NUM_GPUS_TOTAL, sp_size

    output_filepath = out_dir / f"result_{Path(file_path).stem}.mp4"

    launcher = [
        "torchrun",
        "--standalone",
        "--nnodes=1",
        f"--nproc-per-node={effective_nproc}",
        str(cli_script),
    ]
    # --cuda_device is deliberately not passed; per the original author's
    # note, device binding is left to torchrun and LOCAL_RANK.
    cli_args = [
        "--video_path", str(file_path),
        "--output", str(output_filepath),
        "--model_dir", str(self.CKPTS_ROOT),
        "--seed", str(seed),
        "--resolution", str(res_h),
        "--batch_size", str(effective_sp_size),
        "--model", "seedvr2_ema_3b_fp16.safetensors",
        "--preserve_vram",
        "--debug",
        "--output_format", "video",
    ]
    cmd = launcher + cli_args

    # Align the visible device list with the number of processes.
    # Optional debugging aids: set NCCL_DEBUG=WARN or CUDA_LAUNCH_BLOCKING=1
    # in this environment.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = self._visible_devices_for(effective_nproc)

    print("[SeedVRServer] Comando:", " ".join(cmd))
    print("[SeedVRServer] CUDA_VISIBLE_DEVICES:", env.get("CUDA_VISIBLE_DEVICES", ""))

    try:
        subprocess.run(
            cmd,
            cwd=str(self.SEEDVR_ROOT),
            check=True,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
    except Exception as e:
        print(f"[UI ERROR] A inferência falhou: {e}")
        return None, None, out_dir

    if not is_image:
        return None, str(output_filepath), out_dir

    # With the "video" output format the target stays an .mp4 path; the
    # suffix-less form is kept for compatibility in case the CLI treats the
    # output as a directory (as it would for a png format).
    # NOTE(review): confirm that this path exists after an image run.
    if output_filepath.suffix == "":
        image_dir = output_filepath
    else:
        image_dir = output_filepath.with_suffix("")
    return str(image_dir), None, out_dir