Aduc-sdr committed on
Commit
2d69166
·
verified ·
1 Parent(s): 2b766ae

Update managers/seedvr_manager.py

Browse files
Files changed (1) hide show
  1. managers/seedvr_manager.py +56 -53
managers/seedvr_manager.py CHANGED
@@ -2,11 +2,11 @@
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 2.3.3
6
  #
7
- # This version adds a monkey patch to disable torch.distributed.barrier calls
8
- # within the SeedVR library, allowing it to run in a single-GPU inference mode
9
- # without raising a "process group not initialized" error.
10
 
11
  import torch
12
  import torch.distributed as dist
@@ -22,37 +22,46 @@ import gradio as gr
22
  import mediapy
23
  from einops import rearrange
24
 
 
25
  from tools.tensor_utils import wavelet_reconstruction
26
 
27
  logger = logging.getLogger(__name__)
28
 
29
  # --- Dependency Management ---
30
  DEPS_DIR = Path("./deps")
31
- SEEDVR_REPO_DIR = DEPS_DIR / "SeedVR"
32
- SEEDVR_REPO_URL = "https://github.com/ByteDance-Seed/SeedVR.git"
33
- VAE_CONFIG_URL = "https://raw.githubusercontent.com/ByteDance-Seed/SeedVR/main/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml"
 
34
 
35
  def setup_seedvr_dependencies():
36
- """Ensures the SeedVR repository is cloned and available in the sys.path."""
37
- if not SEEDVR_REPO_DIR.exists():
38
- logger.info(f"SeedVR repository not found at '{SEEDVR_REPO_DIR}'. Cloning from GitHub...")
 
 
39
  try:
40
  DEPS_DIR.mkdir(exist_ok=True)
41
- subprocess.run(["git", "clone", "--depth", "1", SEEDVR_REPO_URL, str(SEEDVR_REPO_DIR)], check=True, capture_output=True, text=True)
42
- logger.info("SeedVR repository cloned successfully.")
 
 
 
 
43
  except subprocess.CalledProcessError as e:
44
- logger.error(f"Failed to clone SeedVR repository. Git stderr: {e.stderr}")
45
- raise RuntimeError("Could not clone the required SeedVR dependency from GitHub.")
46
  else:
47
- logger.info("Found local SeedVR repository.")
48
 
49
- if str(SEEDVR_REPO_DIR.resolve()) not in sys.path:
50
- sys.path.insert(0, str(SEEDVR_REPO_DIR.resolve()))
51
- logger.info(f"Added '{SEEDVR_REPO_DIR.resolve()}' to sys.path.")
52
 
53
  setup_seedvr_dependencies()
54
 
55
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
 
56
  from common.config import load_config
57
  from common.seed import set_seed
58
  from data.image.transforms.divisible_crop import DivisibleCrop
@@ -62,7 +71,6 @@ from torchvision.transforms import Compose, Lambda, Normalize
62
  from torchvision.io.video import read_video
63
  from omegaconf import OmegaConf
64
 
65
-
66
  def _load_file_from_url(url, model_dir='./', file_name=None):
67
  os.makedirs(model_dir, exist_ok=True)
68
  filename = file_name or os.path.basename(urlparse(url).path)
@@ -79,17 +87,14 @@ class SeedVrManager:
79
  self.runner = None
80
  self.workspace_dir = workspace_dir
81
  self.is_initialized = False
82
- self._original_barrier = None # To store the original distributed barrier function
83
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
84
 
85
- def _download_models_and_configs(self):
86
- """Downloads the necessary checkpoints AND the missing VAE config file."""
87
- logger.info("Verifying and downloading SeedVR2 models and configs...")
88
- ckpt_dir = SEEDVR_REPO_DIR / 'ckpts'
89
- config_dir = SEEDVR_REPO_DIR / 'configs' / 'vae'
90
  ckpt_dir.mkdir(exist_ok=True)
91
- config_dir.mkdir(parents=True, exist_ok=True)
92
- _load_file_from_url(url=VAE_CONFIG_URL, model_dir=str(config_dir))
93
  pretrain_model_urls = {
94
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
95
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
@@ -99,55 +104,53 @@ class SeedVrManager:
99
  }
100
  for key, url in pretrain_model_urls.items():
101
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
102
- logger.info("SeedVR2 models and configs downloaded successfully.")
103
 
104
  def _initialize_runner(self, model_version: str):
105
- """Loads and configures the SeedVR model, with patches for single-GPU inference."""
106
  if self.runner is not None: return
107
- self._download_models_and_configs()
108
-
109
  if dist.is_available() and not dist.is_initialized():
110
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
111
  self._original_barrier = dist.barrier
112
  dist.barrier = lambda *args, **kwargs: None
113
 
114
- logger.info(f"Initializing SeedVR2 {model_version} runner...")
115
  if model_version == '3B':
116
- config_path = SEEDVR_REPO_DIR / 'configs_3b' / 'main.yaml'
117
- checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_3b.pth'
118
  elif model_version == '7B':
119
- config_path = SEEDVR_REPO_DIR / 'configs_7b' / 'main.yaml'
120
- checkpoint_path = SEEDVR_REPO_DIR / 'ckpts' / 'seedvr2_ema_7b.pth'
 
121
  else:
122
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
123
 
124
- try:
125
- config = load_config(str(config_path))
126
- except FileNotFoundError:
127
- logger.warning("Caught expected FileNotFoundError. Loading config manually.")
128
- config = OmegaConf.load(str(config_path))
129
- correct_vae_config_path = SEEDVR_REPO_DIR / 'configs' / 'vae' / 's8_c16_t4_inflation_sd3.yaml'
130
- vae_config = OmegaConf.load(str(correct_vae_config_path))
131
- config.vae = vae_config
132
- logger.info("Configuration loaded and patched manually.")
133
-
134
  self.runner = VideoDiffusionInfer(config)
135
  OmegaConf.set_readonly(self.runner.config, False)
136
- self.runner.configure_dit_model(device=self.device, checkpoint=str(checkpoint_path))
 
 
 
 
 
137
  self.runner.configure_vae_model()
 
138
  if hasattr(self.runner.vae, "set_memory_limit"):
139
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
140
  self.is_initialized = True
141
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
142
 
143
  def _unload_runner(self):
144
- """Unloads the runner from VRAM and restores any applied patches."""
145
  if self.runner is not None:
146
  del self.runner; self.runner = None
147
  gc.collect(); torch.cuda.empty_cache()
148
  self.is_initialized = False
149
  logger.info("SeedVR runner unloaded from VRAM.")
150
-
151
  if self._original_barrier is not None:
152
  logger.info("Restoring original torch.distributed.barrier function.")
153
  dist.barrier = self._original_barrier
@@ -156,7 +159,7 @@ class SeedVrManager:
156
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
157
  model_version: str = '3B', steps: int = 50, seed: int = 666,
158
  progress: gr.Progress = None) -> str:
159
- """Applies HD enhancement to a video using the SeedVR logic."""
160
  try:
161
  self._initialize_runner(model_version)
162
  set_seed(seed, same_across_ranks=True)
@@ -178,8 +181,8 @@ class SeedVrManager:
178
  cond_latents = self.runner.vae_encode(cond_latents)
179
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
180
  self.runner.dit.to(self.device)
181
- pos_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'pos_emb.pt'
182
- neg_emb_path = SEEDVR_REPO_DIR / 'ckpts' / 'neg_emb.pt'
183
  text_pos_embeds = torch.load(pos_emb_path).to(self.device)
184
  text_neg_embeds = torch.load(neg_emb_path).to(self.device)
185
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}
 
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 2.3.4
6
  #
7
+ # This version is optimized for Hugging Face Spaces environments. It now clones
8
+ # the dependency directly from the official SeedVR HF Space, which is faster,
9
+ # lighter, and more reliable than cloning from GitHub.
10
 
11
  import torch
12
  import torch.distributed as dist
 
22
  import mediapy
23
  from einops import rearrange
24
 
25
+ # Internalized utility for color correction, ensuring stability.
26
  from tools.tensor_utils import wavelet_reconstruction
27
 
28
  logger = logging.getLogger(__name__)
29
 
30
  # --- Dependency Management ---
31
  DEPS_DIR = Path("./deps")
32
+ # Renamed to reflect the new source
33
+ SEEDVR_SPACE_DIR = DEPS_DIR / "SeedVR_Space"
34
+ # NEW: Cloning from the HF Space directly is much more efficient
35
+ SEEDVR_SPACE_URL = "https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B"
36
 
37
  def setup_seedvr_dependencies():
38
+ """
39
+ Ensures the SeedVR Space repository is cloned and available in the sys.path.
40
+ """
41
+ if not SEEDVR_SPACE_DIR.exists():
42
+ logger.info(f"SeedVR Space not found at '{SEEDVR_SPACE_DIR}'. Cloning from Hugging Face...")
43
  try:
44
  DEPS_DIR.mkdir(exist_ok=True)
45
+ # We clone the entire space repo to get its file structure
46
+ subprocess.run(
47
+ ["git", "clone", SEEDVR_SPACE_URL, str(SEEDVR_SPACE_DIR)],
48
+ check=True, capture_output=True, text=True
49
+ )
50
+ logger.info("SeedVR Space cloned successfully.")
51
  except subprocess.CalledProcessError as e:
52
+ logger.error(f"Failed to clone SeedVR Space. Git stderr: {e.stderr}")
53
+ raise RuntimeError("Could not clone the required SeedVR dependency from Hugging Face.")
54
  else:
55
+ logger.info("Found local SeedVR Space repository.")
56
 
57
+ if str(SEEDVR_SPACE_DIR.resolve()) not in sys.path:
58
+ sys.path.insert(0, str(SEEDVR_SPACE_DIR.resolve()))
59
+ logger.info(f"Added '{SEEDVR_SPACE_DIR.resolve()}' to sys.path.")
60
 
61
  setup_seedvr_dependencies()
62
 
63
+ # The imports from a Space are often directly from the root
64
+ from infer import VideoDiffusionInfer
65
  from common.config import load_config
66
  from common.seed import set_seed
67
  from data.image.transforms.divisible_crop import DivisibleCrop
 
71
  from torchvision.io.video import read_video
72
  from omegaconf import OmegaConf
73
 
 
74
  def _load_file_from_url(url, model_dir='./', file_name=None):
75
  os.makedirs(model_dir, exist_ok=True)
76
  filename = file_name or os.path.basename(urlparse(url).path)
 
87
  self.runner = None
88
  self.workspace_dir = workspace_dir
89
  self.is_initialized = False
90
+ self._original_barrier = None
91
  logger.info("SeedVrManager initialized. Model will be loaded on demand.")
92
 
93
+ def _download_models(self):
94
+ """Downloads the necessary checkpoints for SeedVR2."""
95
+ logger.info("Verifying and downloading SeedVR2 models...")
96
+ ckpt_dir = SEEDVR_SPACE_DIR / 'ckpt' # Note: Path in Space repo might be different
 
97
  ckpt_dir.mkdir(exist_ok=True)
 
 
98
  pretrain_model_urls = {
99
  'vae_ckpt': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
100
  'dit_3b': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
 
104
  }
105
  for key, url in pretrain_model_urls.items():
106
  _load_file_from_url(url=url, model_dir=str(ckpt_dir))
107
+ logger.info("SeedVR2 models downloaded successfully.")
108
 
109
  def _initialize_runner(self, model_version: str):
110
+ """Loads and configures the SeedVR model."""
111
  if self.runner is not None: return
112
+ self._download_models()
113
+
114
  if dist.is_available() and not dist.is_initialized():
115
  logger.info("Applying patch to disable torch.distributed.barrier for single-GPU inference.")
116
  self._original_barrier = dist.barrier
117
  dist.barrier = lambda *args, **kwargs: None
118
 
119
+ logger.info(f"Initializing SeedVR2 {model_version} runner from Space repo...")
120
  if model_version == '3B':
121
+ config_path = SEEDVR_SPACE_DIR / 'configs' / 'generate.yaml' # Typical path in a Space
122
+ checkpoint_path = SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-3B' / 'dit.pth'
123
  elif model_version == '7B':
124
+ # Assuming a similar structure for a 7B space if it existed
125
+ config_path = SEEDVR_SPACE_DIR / 'configs' / 'generate_7b.yaml'
126
+ checkpoint_path = SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-7B' / 'dit.pth'
127
  else:
128
  raise ValueError(f"Unsupported SeedVR model version: {model_version}")
129
 
130
+ config = load_config(str(config_path))
131
+
 
 
 
 
 
 
 
 
132
  self.runner = VideoDiffusionInfer(config)
133
  OmegaConf.set_readonly(self.runner.config, False)
134
+ # Manually set the correct checkpoint paths since the config inside the space might be relative
135
+ self.runner.config.dit.checkpoint = str(checkpoint_path)
136
+ self.runner.config.vae.checkpoint = str(SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-3B' / 'vae.pth')
137
+ self.runner.config.text.models[0].path = str(SEEDVR_SPACE_DIR / 'ckpt' / 'VINCIE-3B' / 'llm14b')
138
+
139
+ self.runner.configure_dit_model(device=self.device, checkpoint=self.runner.config.dit.checkpoint)
140
  self.runner.configure_vae_model()
141
+
142
  if hasattr(self.runner.vae, "set_memory_limit"):
143
  self.runner.vae.set_memory_limit(**self.runner.config.vae.memory_limit)
144
  self.is_initialized = True
145
  logger.info(f"Runner for SeedVR2 {model_version} initialized and ready.")
146
 
147
  def _unload_runner(self):
148
+ """Unloads the runner from VRAM and restores patches."""
149
  if self.runner is not None:
150
  del self.runner; self.runner = None
151
  gc.collect(); torch.cuda.empty_cache()
152
  self.is_initialized = False
153
  logger.info("SeedVR runner unloaded from VRAM.")
 
154
  if self._original_barrier is not None:
155
  logger.info("Restoring original torch.distributed.barrier function.")
156
  dist.barrier = self._original_barrier
 
159
  def process_video(self, input_video_path: str, output_video_path: str, prompt: str,
160
  model_version: str = '3B', steps: int = 50, seed: int = 666,
161
  progress: gr.Progress = None) -> str:
162
+ """Applies HD enhancement to a video."""
163
  try:
164
  self._initialize_runner(model_version)
165
  set_seed(seed, same_across_ranks=True)
 
181
  cond_latents = self.runner.vae_encode(cond_latents)
182
  self.runner.vae.to("cpu"); gc.collect(); torch.cuda.empty_cache()
183
  self.runner.dit.to(self.device)
184
+ pos_emb_path = SEEDVR_SPACE_DIR / 'ckpt' / 'pos_emb.pt'
185
+ neg_emb_path = SEEDVR_SPACE_DIR / 'ckpt' / 'neg_emb.pt'
186
  text_pos_embeds = torch.load(pos_emb_path).to(self.device)
187
  text_neg_embeds = torch.load(neg_emb_path).to(self.device)
188
  text_embeds_dict = {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}