Aduc-sdr commited on
Commit
4fe7fa0
·
verified ·
1 Parent(s): 09b6a7e

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +39 -42
managers/seedvr_manager.py CHANGED
@@ -2,11 +2,12 @@
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 2.3.5
6
  #
7
- # This version uses the optimal strategy of cloning the self-contained Hugging Face
8
- # Space repository and uses the full, correct import paths to resolve all
9
- # ModuleNotFoundErrors, while retaining necessary runtime patches.
 
10
 
11
  import torch
12
  import torch.distributed as dist
@@ -26,38 +27,41 @@ from tools.tensor_utils import wavelet_reconstruction
26
 
27
  logger = logging.getLogger(__name__)
28
 
29
- # --- Dependency Management ---
30
  DEPS_DIR = Path("./deps")
31
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
32
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
33
- VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
34
 
35
  def setup_seedvr_dependencies():
36
  """
37
- Ensures the SeedVR Space repository is cloned and available in the sys.path.
 
38
  """
39
  if not SEEDVR_SPACE_DIR.exists():
40
  logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
41
  try:
42
  DEPS_DIR.mkdir(exist_ok=True)
 
43
  subprocess.run(
44
- ["git", "clone", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
45
  check=True, capture_output=True, text=True
46
  )
47
- logger.info("SeedVR Space cloned successfully.")
48
  except subprocess.CalledProcessError as e:
49
  logger.error(f"Failed to clone SeedVR Space. Git stderr: {e.stderr}")
50
  raise RuntimeError("Could not clone the required SeedVR dependency from Hugging Face.")
51
  else:
52
  logger.info("Found local SeedVR Space repository.")
53
 
 
54
  if str(SEEDVR_SPACE_DIR.resolve()) not in sys.path:
55
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
56
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
57
 
 
58
  setup_seedvr_dependencies()
59
 
60
- # Use full import paths relative to the root of the cloned repository
61
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
62
  from common.config import load_config
63
  from common.seed import set_seed
@@ -67,19 +71,11 @@ from data.video.transforms.rearrange import Rearrange
67
  from torchvision.transforms import Compose, Lambda, Normalize
68
  from torchvision.io.video import read_video
69
  from omegaconf import OmegaConf
 
70
 
71
 
72
- def _load_file_from_url(url, model_dir='./', file_name=None):
73
- os.makedirs(model_dir, exist_ok=True)
74
- filename = file_name or os.path.basename(urlparse(url).path)
75
- cached_file = os.path.abspath(os.path.join(model_dir, filename))
76
- if not os.path.exists(cached_file):
77
- logger.info(f'Downloading: "{url}" to {cached_file}')
78
- download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
79
- return cached_file
80
-
81
  class SeedVrManager:
82
- """Manages the SeedVR model for HD Mastering tasks."""
83
  def __init__(self, workspace_dir="deformes_workspace"):
84
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
85
  self.runner = None
@@ -88,14 +84,12 @@ class SeedVrManager:
88
  self._original_barrier = None
89
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
90
 
91
- def _download_models_and_configs(self):
92
- """Downloads the necessary checkpoints AND the missing VAE config file."""
93
- logger.info("Verifying and downloading SeedVR2 models and configs...")
94
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
95
- config_dir = SEEDVR_SPACE_DIR / 'configs' / 'vae'
96
  ckpt_dir.mkdir(exist_ok=True)
97
- config_dir.mkdir(parents=True, exist_ok=True)
98
- _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
99
  pretrain_model_urls = {
100
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
101
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
@@ -105,12 +99,12 @@ class SeedVrManager:
105
  }
106
  for key, url in pretrain_model_urls.items():
107
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
108
- logger.info("SeedVR2 models and configs downloaded successfully.")
109
 
110
  def _initialize_runner(self, model_version: str):
111
- """Loads and configures the SeedVR model, with patches for single-GPU inference."""
112
  if self.runner is not None: return
113
- self._download_models_and_configs()
114
 
115
  if dist.is_available() and not dist.is_initialized():
116
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
@@ -127,16 +121,9 @@ class SeedVrManager:
127
  else:
128
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
129
 
130
- try:
131
- config = load_config(str(config_path))
132
- except FileNotFoundError:
133
- logger.warning("Caught expected FileNotFoundError. Loading config manually.")
134
- config = OmegaConf.load(str(config_path))
135
- correct_vae_config_path = SEEDVR_SPACE_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
136
- vae_config = OmegaConf.load(str(correct_vae_config_path))
137
- config.vae = vae_config
138
- logger.info("Configuration loaded and patched manually.")
139
-
140
  self.runner = VideoDiffusionInfer(config)
141
  OmegaConf.set_readonly(self.runner.config, False)
142
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
@@ -147,7 +134,7 @@ class SeedVrManager:
147
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
148
 
149
  def _unload_runner(self):
150
- """Unloads the runner from VRAM and restores patches."""
151
  if self.runner is not None:
152
  del self.runner; self.runner = None
153
  gc.collect(); torch.cuda.empty_cache()
@@ -161,7 +148,7 @@ class SeedVrManager:
161
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
162
  model_version: str = '3B', steps: int = 50, seed: int = 666,
163
  progress: gr.Progress = None) -> str:
164
- """Applies HD enhancement to a video."""
165
  try:
166
  self._initialize_runner(model_version)
167
  set_seed(seed, same_across_ranks=True)
@@ -209,5 +196,15 @@ class SeedVrManager:
209
  finally:
210
  self._unload_runner()
211
 
212
- # --- Singleton Instance ---
 
 
 
 
 
 
 
 
 
 
213
  seedvr_manager_singleton = SeedVrManager()
 
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 3.0.0 (Robust Dependency Management)
6
  #
7
+ # This version implements a robust dependency management strategy by cloning the
8
+ # entire SeedVR Hugging Face Space repository. This ensures that all necessary
9
+ # modules and sub-packages are locally available, resolving the 'flash_attn'
10
+ # ModuleNotFoundError and other potential import issues without relying on pip.
11
 
12
  import torch
13
  import torch.distributed as dist
 
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
+ # --- INÍCIO DA NOVA SEÇÃO DE GERENCIAMENTO DE DEPENDÊNCIAS ---
31
  DEPS_DIR = Path("./deps")
32
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
33
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
 
34
 
35
  def setup_seedvr_dependencies():
36
  """
37
+ Ensures the SeedVR Space repository is cloned and its path is added to sys.path,
38
+ making all its internal modules importable.
39
  """
40
  if not SEEDVR_SPACE_DIR.exists():
41
  logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
42
  try:
43
  DEPS_DIR.mkdir(exist_ok=True)
44
+ # Usamos --depth 1 para um clone mais rápido, já que não precisamos do histórico
45
  subprocess.run(
46
+ ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
47
  check=True, capture_output=True, text=True
48
  )
49
+ logger.info("SeedVR Space repository cloned successfully.")
50
  except subprocess.CalledProcessError as e:
51
  logger.error(f"Failed to clone SeedVR Space. Git stderr: {e.stderr}")
52
  raise RuntimeError("Could not clone the required SeedVR dependency from Hugging Face.")
53
  else:
54
  logger.info("Found local SeedVR Space repository.")
55
 
56
+ # Adiciona o diretório clonado ao path do Python para que os imports funcionem
57
  if str(SEEDVR_SPACE_DIR.resolve()) not in sys.path:
58
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
59
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
60
 
61
+ # Executa a configuração das dependências assim que o módulo é importado
62
  setup_seedvr_dependencies()
63
 
64
+ # Agora que o path está ajustado, podemos importar os módulos do SeedVR
65
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
66
  from common.config import load_config
67
  from common.seed import set_seed
 
71
  from torchvision.transforms import Compose, Lambda, Normalize
72
  from torchvision.io.video import read_video
73
  from omegaconf import OmegaConf
74
+ # --- FIM DA NOVA SEÇÃO ---
75
 
76
 
 
 
 
 
 
 
 
 
 
77
  class SeedVrManager:
78
+ # ... (o resto da classe permanece o mesmo, mas a lógica de download de configs pode ser simplificada) ...
79
  def __init__(self, workspace_dir="deformes_workspace"):
80
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
81
  self.runner = None
 
84
  self._original_barrier = None
85
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
86
 
87
+ def _download_models(self):
88
+ """Downloads only the necessary model checkpoints."""
89
+ logger.info("Verifying and downloading SeedVR2 model checkpoints...")
90
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
 
91
  ckpt_dir.mkdir(exist_ok=True)
92
+
 
93
  pretrain_model_urls = {
94
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
95
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
 
99
  }
100
  for key, url in pretrain_model_urls.items():
101
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
102
+ logger.info("SeedVR2 model checkpoints downloaded successfully.")
103
 
104
  def _initialize_runner(self, model_version: str):
105
+ """Loads and configures the SeedVR model."""
106
  if self.runner is not None: return
107
+ self._download_models() # Garante que os pesos estão baixados
108
 
109
  if dist.is_available() and not dist.is_initialized():
110
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
 
121
  else:
122
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
123
 
124
+ # Como clonamos o repo, os arquivos de config estarão lá.
125
+ config = load_config(str(config_path))
126
+
 
 
 
 
 
 
 
127
  self.runner = VideoDiffusionInfer(config)
128
  OmegaConf.set_readonly(self.runner.config, False)
129
  self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
 
134
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
135
 
136
  def _unload_runner(self):
137
+ # ... (sem alterações aqui)
138
  if self.runner is not None:
139
  del self.runner; self.runner = None
140
  gc.collect(); torch.cuda.empty_cache()
 
148
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
149
  model_version: str = '3B', steps: int = 50, seed: int = 666,
150
  progress: gr.Progress = None) -> str:
151
+ # ... (sem alterações aqui)
152
  try:
153
  self._initialize_runner(model_version)
154
  set_seed(seed, same_across_ranks=True)
 
196
  finally:
197
  self._unload_runner()
198
 
199
+ # Helper function (pode ser movida para um utils se necessário)
200
+ def _load_file_from_url(url, model_dir='./', file_name=None):
201
+ os.makedirs(model_dir, exist_ok=True)
202
+ filename = file_name or os.path.basename(urlparse(url).path)
203
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
204
+ if not os.path.exists(cached_file):
205
+ logger.info(f'Downloading: "{url}" to {cached_file}')
206
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
207
+ return cached_file
208
+
209
+ # --- Singleton Instantiation ---
210
  seedvr_manager_singleton = SeedVrManager()