Aduc-sdr commited on
Commit
2d677c7
·
verified ·
1 Parent(s): 6a4b97d

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +37 -31
managers/seedvr_manager.py CHANGED
@@ -1,19 +1,21 @@
1
  # managers/seedvr_manager.py
2
  #
3
- # Version: 3.2.0 (3B Model Focus)
4
  #
5
- # This version simplifies the manager to exclusively use the SeedVR 3B model.
6
- # The 7B model download and selection logic have been removed to streamline
7
- # the code and reduce resource usage.
 
 
 
8
 
9
- # ... (imports permanecem os mesmos) ...
10
  import torch
11
  import torch.distributed as dist
12
  import os
13
  import gc
14
  import logging
15
  import sys
16
- import subprocess
17
  from pathlib import Path
18
  from urllib.parse import urlparse
19
  from torch.hub import download_url_to_file
@@ -21,10 +23,11 @@ import gradio as gr
21
  import mediapy
22
  from einops import rearrange
23
  import shutil
 
24
  from tools.tensor_utils import wavelet_reconstruction
25
 
26
  logger = logging.getLogger(__name__)
27
- # ... (setup_seedvr_environment_and_dependencies e imports do seedvr permanecem os mesmos) ...
28
  # --- INÍCIO DA SEÇÃO DE GERENCIAMENTO DE DEPENDÊNCIAS E AMBIENTE ---
29
  DEPS_DIR = Path("./deps")
30
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
@@ -64,6 +67,8 @@ def setup_seedvr_environment_and_dependencies():
64
  logger.info("flash-attn installed successfully.")
65
  except subprocess.CalledProcessError as e:
66
  logger.error(f"Failed to install flash-attn. Stderr: {e.stderr}")
 
 
67
 
68
  # 3. Clonar o repositório do SeedVR Space
69
  if not SEEDVR_SPACE_DIR.exists():
@@ -86,19 +91,21 @@ def setup_seedvr_environment_and_dependencies():
86
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
87
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
88
 
 
89
  setup_seedvr_environment_and_dependencies()
90
 
 
91
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
92
  from common.config import load_config
93
  from common.seed import set_seed
 
94
  from torchvision.io.video import read_video
95
  from omegaconf import OmegaConf
96
- from data.image.transforms.divisible_crop import DivisibleCrop
97
- from data.image.transforms.na_resize import NaResize
98
- from data.video.transforms.rearrange import Rearrange
99
- from torchvision.transforms import Compose, Lambda, Normalize
100
 
101
  class SeedVrManager:
 
102
  def __init__(self, workspace_dir="deformes_workspace"):
103
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
104
  self.runner = None
@@ -108,7 +115,6 @@ class SeedVrManager:
108
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
109
 
110
  def _patch_config_paths(self):
111
- # ... (sem alterações) ...
112
  app_root = Path("/home/user/app")
113
  source_config_dir = SEEDVR_SPACE_DIR / "models" / "video_vae_v3"
114
  target_config_parent_dir = app_root / "models"
@@ -129,36 +135,38 @@ class SeedVrManager:
129
  raise IOError("Could not patch the required SeedVR configuration paths.")
130
 
131
  def _download_models(self):
132
- logger.info("Verifying and downloading SeedVR2 3B model checkpoints...")
133
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
134
  ckpt_dir.mkdir(exist_ok=True)
135
-
136
  pretrain_model_urls = {
137
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
138
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
139
- # 'dit_7b' REMOVIDO
140
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
141
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
142
  }
143
  for key, url in pretrain_model_urls.items():
144
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
145
- logger.info("SeedVR2 3B model checkpoints downloaded successfully.")
146
 
147
- def _initialize_runner(self): # <--- REMOVIDO model_version
148
  if self.runner is not None: return
149
  self._patch_config_paths()
150
  self._download_models()
151
-
152
  if dist.is_available() and not dist.is_initialized():
 
153
  self._original_barrier = dist.barrier
154
  dist.barrier = lambda *args, **kwargs: None
155
-
156
- logger.info("Initializing SeedVR2 3B runner...")
157
- config_path = SEEDVR_SPACE_DIR / 'configs_3b' / 'main.yaml'
158
- checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
159
-
 
 
 
 
160
  config = load_config(str(config_path))
161
-
162
  self.runner = VideoDiffusionInfer(config)
163
  OmegaConf.set_readonly(self.runner.config, False)
164
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
@@ -166,28 +174,27 @@ class SeedVrManager:
166
  if hasattr(self.runner.vae, "set_memory_limit"):
167
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
168
  self.is_initialized = True
169
- logger.info("Runner for SeedVR2 3B initialized and ready.")
170
 
171
  def _unload_runner(self):
172
- # ... (sem alterações) ...
173
  if self.runner is not None:
174
  del self.runner; self.runner = None
175
  gc.collect(); torch.cuda.empty_cache()
176
  self.is_initialized = False
177
  logger.info("SeedVR runner unloaded from VRAM.")
178
  if self._original_barrier is not None:
 
179
  dist.barrier = self._original_barrier
180
  self._original_barrier = None
181
 
182
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
183
- steps: int = 50, seed: int = 666,
184
- progress: gr.Progress = None) -> str: # <--- REMOVIDO model_version
185
  try:
186
- self._initialize_runner() # <--- REMOVIDO model_version
187
  set_seed(seed, same_across_ranks=True)
188
  self.runner.config.diffusion.timesteps.sampling.steps = steps
189
  self.runner.configure_diffusion()
190
- # ... (resto da função sem alterações) ...
191
  video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
192
  res_h, res_w = video_tensor.shape[-2:]
193
  video_transform = Compose([
@@ -231,7 +238,6 @@ class SeedVrManager:
231
  self._unload_runner()
232
 
233
  def _load_file_from_url(url, model_dir='./', file_name=None):
234
- # ... (sem alterações) ...
235
  os.makedirs(model_dir, exist_ok=True)
236
  filename = file_name or os.path.basename(urlparse(url).path)
237
  cached_file = os.path.abspath(os.path.join(model_dir, filename))
 
1
  # managers/seedvr_manager.py
2
  #
3
+ # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 3.1.0 (Full Environment Setup)
6
+ #
7
+ # This version now fully replicates the environment setup from the original
8
+ # SeedVR Space. It sets the necessary torch.distributed environment variables
9
+ # and forces the installation of flash-attn via subprocess, ensuring complete
10
+ # compatibility and resolving runtime dependency issues.
11
 
 
12
  import torch
13
  import torch.distributed as dist
14
  import os
15
  import gc
16
  import logging
17
  import sys
18
+ import subprocess # <--- NOVO IMPORT
19
  from pathlib import Path
20
  from urllib.parse import urlparse
21
  from torch.hub import download_url_to_file
 
23
  import mediapy
24
  from einops import rearrange
25
  import shutil
26
+
27
  from tools.tensor_utils import wavelet_reconstruction
28
 
29
  logger = logging.getLogger(__name__)
30
+
31
  # --- INÍCIO DA SEÇÃO DE GERENCIAMENTO DE DEPENDÊNCIAS E AMBIENTE ---
32
  DEPS_DIR = Path("./deps")
33
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
 
67
  logger.info("flash-attn installed successfully.")
68
  except subprocess.CalledProcessError as e:
69
  logger.error(f"Failed to install flash-attn. Stderr: {e.stderr}")
70
+ # Não lançamos um erro aqui, pois pode não ser fatal em todos os sistemas
71
+ # O import posterior vai falhar se for realmente necessário.
72
 
73
  # 3. Clonar o repositório do SeedVR Space
74
  if not SEEDVR_SPACE_DIR.exists():
 
91
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
92
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
93
 
94
+ # Executa o setup completo uma única vez
95
  setup_seedvr_environment_and_dependencies()
96
 
97
+ # Agora que o setup está completo, os imports devem funcionar
98
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
99
  from common.config import load_config
100
  from common.seed import set_seed
101
+ # ... (outros imports do seedvr)
102
  from torchvision.io.video import read_video
103
  from omegaconf import OmegaConf
104
+ # --- FIM DA SEÇÃO DE SETUP ---
105
+
 
 
106
 
107
  class SeedVrManager:
108
+ # ... (o resto do código permanece o mesmo da nossa última versão) ...
109
  def __init__(self, workspace_dir="deformes_workspace"):
110
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
111
  self.runner = None
 
115
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
116
 
117
  def _patch_config_paths(self):
 
118
  app_root = Path("/home/user/app")
119
  source_config_dir = SEEDVR_SPACE_DIR / "models" / "video_vae_v3"
120
  target_config_parent_dir = app_root / "models"
 
135
  raise IOError("Could not patch the required SeedVR configuration paths.")
136
 
137
  def _download_models(self):
138
+ logger.info("Verifying and downloading SeedVR2 model checkpoints...")
139
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
140
  ckpt_dir.mkdir(exist_ok=True)
 
141
  pretrain_model_urls = {
142
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
143
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
144
+ 'dit_7b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-7B/resolve/main/seedvr2_ema_7b.pth',
145
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
146
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt'
147
  }
148
  for key, url in pretrain_model_urls.items():
149
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
150
+ logger.info("SeedVR2 model checkpoints downloaded successfully.")
151
 
152
+ def _initialize_runner(self, model_version: str):
153
  if self.runner is not None: return
154
  self._patch_config_paths()
155
  self._download_models()
 
156
  if dist.is_available() and not dist.is_initialized():
157
+ logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
158
  self._original_barrier = dist.barrier
159
  dist.barrier = lambda *args, **kwargs: None
160
+ logger.info(f"Initializing SeedVR2 {model_version} runner...")
161
+ if model_version == '3B':
162
+ config_path = SEEDVR_SPACE_DIR / 'configs_3b' / 'main.yaml'
163
+ checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
164
+ elif model_version == '7B':
165
+ config_path = SEEDVR_SPACE_DIR / 'configs_7b' / 'main.yaml'
166
+ checkpoint_path = SEEDVR_SPACE_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
167
+ else:
168
+ raise ValueError(f"Unsupported SeedVR model version: {model_version}")
169
  config = load_config(str(config_path))
 
170
  self.runner = VideoDiffusionInfer(config)
171
  OmegaConf.set_readonly(self.runner.config, False)
172
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
 
174
  if hasattr(self.runner.vae, "set_memory_limit"):
175
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
176
  self.is_initialized = True
177
+ logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
178
 
179
  def _unload_runner(self):
 
180
  if self.runner is not None:
181
  del self.runner; self.runner = None
182
  gc.collect(); torch.cuda.empty_cache()
183
  self.is_initialized = False
184
  logger.info("SeedVR runner unloaded from VRAM.")
185
  if self._original_barrier is not None:
186
+ logger.info("Restoring original torch.distributed.barrier function.")
187
  dist.barrier = self._original_barrier
188
  self._original_barrier = None
189
 
190
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
191
+ model_version: str = '3B', steps: int = 50, seed: int = 666,
192
+ progress: gr.Progress = None) -> str:
193
  try:
194
+ self._initialize_runner(model_version)
195
  set_seed(seed, same_across_ranks=True)
196
  self.runner.config.diffusion.timesteps.sampling.steps = steps
197
  self.runner.configure_diffusion()
 
198
  video_tensor = read_video(input_video_path, output_format="TCHW")[0] / 255.0
199
  res_h, res_w = video_tensor.shape[-2:]
200
  video_transform = Compose([
 
238
  self._unload_runner()
239
 
240
  def _load_file_from_url(url, model_dir='./', file_name=None):
 
241
  os.makedirs(model_dir, exist_ok=True)
242
  filename = file_name or os.path.basename(urlparse(url).path)
243
  cached_file = os.path.abspath(os.path.join(model_dir, filename))