Aduc-sdr commited on
Commit
6e99af4
·
verified ·
1 Parent(s): 4fe7fa0

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +51 -27
managers/seedvr_manager.py CHANGED
@@ -2,12 +2,12 @@
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 3.0.0 (Robust Dependency Management)
6
  #
7
- # This version implements a robust dependency management strategy by cloning the
8
- # entire SeedVR Hugging Face Space repository. This ensures that all necessary
9
- # modules and sub-packages are locally available, resolving the 'flash_attn'
10
- # ModuleNotFoundError and other potential import issues without relying on pip.
11
 
12
  import torch
13
  import torch.distributed as dist
@@ -22,26 +22,23 @@ from torch.hub import download_url_to_file
22
  import gradio as gr
23
  import mediapy
24
  from einops import rearrange
 
25
 
26
  from tools.tensor_utils import wavelet_reconstruction
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
- # --- INÍCIO DA NOVA SEÇÃO DE GERENCIAMENTO DE DEPENDÊNCIAS ---
31
  DEPS_DIR = Path("./deps")
32
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
33
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
34
 
35
  def setup_seedvr_dependencies():
36
- """
37
- Ensures the SeedVR Space repository is cloned and its path is added to sys.path,
38
- making all its internal modules importable.
39
- """
40
  if not SEEDVR_SPACE_DIR.exists():
41
  logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
42
  try:
43
  DEPS_DIR.mkdir(exist_ok=True)
44
- # Usamos --depth 1 para um clone mais rápido, já que não precisamos do histórico
45
  subprocess.run(
46
  ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
47
  check=True, capture_output=True, text=True
@@ -53,30 +50,24 @@ def setup_seedvr_dependencies():
53
  else:
54
  logger.info("Found local SeedVR Space repository.")
55
 
56
- # Adiciona o diretório clonado ao path do Python para que os imports funcionem
57
  if str(SEEDVR_SPACE_DIR.resolve()) not in sys.path:
58
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
59
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
60
 
61
- # Executa a configuração das dependências assim que o módulo é importado
62
  setup_seedvr_dependencies()
63
 
64
- # Agora que o path está ajustado, podemos importar os módulos do SeedVR
65
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
66
  from common.config import load_config
67
  from common.seed import set_seed
68
  from data.image.transforms.divisible_crop import DivisibleCrop
69
- from data.image.transforms.na_resize import NaResize
70
- from data.video.transforms.rearrange import Rearrange
71
- from torchvision.transforms import Compose, Lambda, Normalize
72
  from torchvision.io.video import read_video
73
  from omegaconf import OmegaConf
74
- # --- FIM DA NOVA SEÇÃO ---
75
 
76
 
77
  class SeedVrManager:
78
- # ... (o resto da classe permanece o mesmo, mas a lógica de download de configs pode ser simplificada) ...
79
  def __init__(self, workspace_dir="deformes_workspace"):
 
80
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
81
  self.runner = None
82
  self.workspace_dir = workspace_dir
@@ -84,8 +75,39 @@ class SeedVrManager:
84
  self._original_barrier = None
85
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def _download_models(self):
88
- """Downloads only the necessary model checkpoints."""
89
  logger.info("Verifying and downloading SeedVR2 model checkpoints...")
90
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
91
  ckpt_dir.mkdir(exist_ok=True)
@@ -104,9 +126,14 @@ class SeedVrManager:
104
  def _initialize_runner(self, model_version: str):
105
  """Loads and configures the SeedVR model."""
106
  if self.runner is not None: return
107
- self._download_models() # Garante que os pesos estão baixados
 
 
 
 
108
 
109
  if dist.is_available() and not dist.is_initialized():
 
110
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
111
  self._original_barrier = dist.barrier
112
  dist.barrier = lambda *args, **kwargs: None
@@ -121,7 +148,7 @@ class SeedVrManager:
121
  else:
122
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
123
 
124
- # Como clonamos o repo, os arquivos de config estarão lá.
125
  config = load_config(str(config_path))
126
 
127
  self.runner = VideoDiffusionInfer(config)
@@ -133,8 +160,8 @@ class SeedVrManager:
133
  self.is_initialized = True
134
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
135
 
 
136
  def _unload_runner(self):
137
- # ... (sem alterações aqui)
138
  if self.runner is not None:
139
  del self.runner; self.runner = None
140
  gc.collect(); torch.cuda.empty_cache()
@@ -148,7 +175,6 @@ class SeedVrManager:
148
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
149
  model_version: str = '3B', steps: int = 50, seed: int = 666,
150
  progress: gr.Progress = None) -> str:
151
- # ... (sem alterações aqui)
152
  try:
153
  self._initialize_runner(model_version)
154
  set_seed(seed, same_across_ranks=True)
@@ -192,11 +218,10 @@ class SeedVrManager:
192
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
193
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
194
  logger.info(f"HD Mastered video saved to: {output_video_path}")
195
- return output_video_path
196
  finally:
197
  self._unload_runner()
198
 
199
- # Helper function (pode ser movida para um utils se necessário)
200
  def _load_file_from_url(url, model_dir='./', file_name=None):
201
  os.makedirs(model_dir, exist_ok=True)
202
  filename = file_name or os.path.basename(urlparse(url).path)
@@ -206,5 +231,4 @@ def _load_file_from_url(url, model_dir='./', file_name=None):
206
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
207
  return cached_file
208
 
209
- # --- Singleton Instantiation ---
210
  seedvr_manager_singleton = SeedVrManager()
 
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 3.0.1 (Config Path Patch)
6
  #
7
+ # This version adds a path patching mechanism. It copies the necessary VAE
8
+ # configuration files from the cloned SeedVR dependency directory to the
9
+ # location where the SeedVR code hardcodedly expects them, resolving the
10
+ # FileNotFoundError during initialization.
11
 
12
  import torch
13
  import torch.distributed as dist
 
22
  import gradio as gr
23
  import mediapy
24
  from einops import rearrange
25
+ import shutil # <--- NOVO IMPORT
26
 
27
  from tools.tensor_utils import wavelet_reconstruction
28
 
29
  logger = logging.getLogger(__name__)
30
 
31
+ # --- Gerenciamento de Dependências (sem alterações) ---
32
  DEPS_DIR = Path("./deps")
33
  SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
34
  SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
35
 
36
  def setup_seedvr_dependencies():
37
+ # ... (sem alterações aqui)
 
 
 
38
  if not SEEDVR_SPACE_DIR.exists():
39
  logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
40
  try:
41
  DEPS_DIR.mkdir(exist_ok=True)
 
42
  subprocess.run(
43
  ["git", "clone", "--depth", "1", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
44
  check=True, capture_output=True, text=True
 
50
  else:
51
  logger.info("Found local SeedVR Space repository.")
52
 
 
53
  if str(SEEDVR_SPACE_DIR.resolve()) not in sys.path:
54
  sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
55
  logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
56
 
 
57
  setup_seedvr_dependencies()
58
 
 
59
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
60
  from common.config import load_config
61
  from common.seed import set_seed
62
  from data.image.transforms.divisible_crop import DivisibleCrop
63
+ # ... (outros imports do seedvr sem alterações)
 
 
64
  from torchvision.io.video import read_video
65
  from omegaconf import OmegaConf
 
66
 
67
 
68
  class SeedVrManager:
 
69
  def __init__(self, workspace_dir="deformes_workspace"):
70
+ # ... (sem alterações aqui)
71
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
72
  self.runner = None
73
  self.workspace_dir = workspace_dir
 
75
  self._original_barrier = None
76
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
77
 
78
+ # <--- INÍCIO DA NOVA FUNÇÃO DE PATCH --->
79
+ def _patch_config_paths(self):
80
+ """
81
+ Copies the VAE config directory from the cloned repo to the hardcoded
82
+ path that the SeedVR library expects.
83
+ """
84
+ app_root = Path("/home/user/app")
85
+ source_config_dir = SEEDVR_SPACE_DIR / "models" / "video_vae_v3"
86
+ target_config_parent_dir = app_root / "models"
87
+ target_config_dir = target_config_parent_dir / "video_vae_v3"
88
+
89
+ if not source_config_dir.exists():
90
+ logger.warning(f"Source VAE config directory not found at {source_config_dir}. Skipping patch.")
91
+ return
92
+
93
+ if target_config_dir.exists():
94
+ logger.info(f"Target VAE config path {target_config_dir} already exists. Skipping copy.")
95
+ return
96
+
97
+ logger.info(f"Patching SeedVR config path: Copying {source_config_dir} to {target_config_dir}...")
98
+ try:
99
+ # Cria o diretório pai (/home/user/app/models) se ele não existir
100
+ target_config_parent_dir.mkdir(parents=True, exist_ok=True)
101
+ # Copia a árvore de diretórios inteira
102
+ shutil.copytree(source_config_dir, target_config_dir)
103
+ logger.info("Config path patched successfully.")
104
+ except Exception as e:
105
+ logger.error(f"Failed to patch SeedVR config path: {e}", exc_info=True)
106
+ raise IOError("Could not patch the required SeedVR configuration paths.")
107
+ # <--- FIM DA NOVA FUNÇÃO DE PATCH --->
108
+
109
  def _download_models(self):
110
+ # ... (sem alterações aqui)
111
  logger.info("Verifying and downloading SeedVR2 model checkpoints...")
112
  ckpt_dir = SEEDVR_SPACE_DIR / 'ckpts'
113
  ckpt_dir.mkdir(exist_ok=True)
 
126
  def _initialize_runner(self, model_version: str):
127
  """Loads and configures the SeedVR model."""
128
  if self.runner is not None: return
129
+
130
+ # Chama o patch ANTES de tentar carregar qualquer coisa
131
+ self._patch_config_paths()
132
+
133
+ self._download_models()
134
 
135
  if dist.is_available() and not dist.is_initialized():
136
+ # ... (patch do barrier sem alterações)
137
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
138
  self._original_barrier = dist.barrier
139
  dist.barrier = lambda *args, **kwargs: None
 
148
  else:
149
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
150
 
151
+ # Agora, quando `load_config` for chamado, ele encontrará o arquivo no caminho esperado.
152
  config = load_config(str(config_path))
153
 
154
  self.runner = VideoDiffusionInfer(config)
 
160
  self.is_initialized = True
161
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
162
 
163
+ # ... (o resto da classe e do arquivo permanece o mesmo)
164
  def _unload_runner(self):
 
165
  if self.runner is not None:
166
  del self.runner; self.runner = None
167
  gc.collect(); torch.cuda.empty_cache()
 
175
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
176
  model_version: str = '3B', steps: int = 50, seed: int = 666,
177
  progress: gr.Progress = None) -> str:
 
178
  try:
179
  self._initialize_runner(model_version)
180
  set_seed(seed, same_across_ranks=True)
 
218
  final_sample_np = final_sample.to(torch.uint8).cpu().numpy()
219
  mediapy.write_video(output_video_path, final_sample_np, fps=24)
220
  logger.info(f"HD Mastered video saved to: {output_video_path}")
221
+ return output_path
222
  finally:
223
  self._unload_runner()
224
 
 
225
  def _load_file_from_url(url, model_dir='./', file_name=None):
226
  os.makedirs(model_dir, exist_ok=True)
227
  filename = file_name or os.path.basename(urlparse(url).path)
 
231
  download_url_to_file(url, cached_file, hash_prefix=None, progress=True)
232
  return cached_file
233
 
 
234
  seedvr_manager_singleton = SeedVrManager()