Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -16,23 +16,23 @@ import subprocess
|
|
| 16 |
import os
|
| 17 |
import sys
|
| 18 |
|
| 19 |
-
# --- Setup: Clone repository and
|
| 20 |
-
# This
|
| 21 |
|
| 22 |
# 1. Clone the repository with all its files
|
| 23 |
subprocess.run("git lfs install", shell=True, check=True)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
-
# 2.
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
print(f"Repository directory '{os.path.abspath(repo_dir)}' added to Python path.")
|
| 33 |
|
| 34 |
# --- Main Application Code ---
|
| 35 |
-
#
|
| 36 |
|
| 37 |
import torch
|
| 38 |
import mediapy
|
|
@@ -51,7 +51,7 @@ import torchvision.transforms as T
|
|
| 51 |
from torchvision.transforms import Compose, Lambda, Normalize
|
| 52 |
from torchvision.io.video import read_video
|
| 53 |
|
| 54 |
-
# Imports from the
|
| 55 |
from data.image.transforms.divisible_crop import DivisibleCrop
|
| 56 |
from data.image.transforms.na_resize import NaResize
|
| 57 |
from data.video.transforms.rearrange import Rearrange
|
|
@@ -63,9 +63,8 @@ from common.partition import partition_by_size
|
|
| 63 |
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
|
| 64 |
from common.distributed.ops import sync_data
|
| 65 |
|
| 66 |
-
# Check for color_fix utility
|
| 67 |
-
|
| 68 |
-
if os.path.exists(color_fix_path):
|
| 69 |
from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
|
| 70 |
use_colorfix = True
|
| 71 |
else:
|
|
@@ -78,7 +77,7 @@ os.environ["MASTER_PORT"] = "12355"
|
|
| 78 |
os.environ["RANK"] = str(0)
|
| 79 |
os.environ["WORLD_SIZE"] = str(1)
|
| 80 |
|
| 81 |
-
#
|
| 82 |
python_executable = sys.executable
|
| 83 |
subprocess.run(
|
| 84 |
[python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
|
|
@@ -86,9 +85,8 @@ subprocess.run(
|
|
| 86 |
check=True
|
| 87 |
)
|
| 88 |
|
| 89 |
-
apex_wheel_path =
|
| 90 |
if os.path.exists(apex_wheel_path):
|
| 91 |
-
# CORREÇÃO: Usar sys.executable aqui também
|
| 92 |
subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
|
| 93 |
print("✅ Apex setup completed.")
|
| 94 |
|
|
@@ -99,10 +97,11 @@ def configure_sequence_parallel(sp_size):
|
|
| 99 |
init_sequence_parallel(sp_size)
|
| 100 |
|
| 101 |
def configure_runner(sp_size):
|
| 102 |
-
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
-
config = load_config(config_path)
|
| 106 |
runner = VideoDiffusionInfer(config)
|
| 107 |
OmegaConf.set_readonly(runner.config, False)
|
| 108 |
|
|
@@ -158,8 +157,9 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
|
|
| 158 |
def _extract_text_embeds():
|
| 159 |
positive_prompts_embeds = []
|
| 160 |
for _ in original_videos_local:
|
| 161 |
-
|
| 162 |
-
|
|
|
|
| 163 |
positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
|
| 164 |
gc.collect()
|
| 165 |
torch.cuda.empty_cache()
|
|
@@ -276,13 +276,12 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
|
|
| 276 |
# --- Gradio UI ---
|
| 277 |
|
| 278 |
with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
|
| 279 |
-
|
|
|
|
| 280 |
gr.HTML(f"""
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
</
|
| 284 |
-
<p><b>Official Gradio demo</b> for <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
|
| 285 |
-
🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.</p>
|
| 286 |
""")
|
| 287 |
|
| 288 |
with gr.Row():
|
|
|
|
| 16 |
import os
|
| 17 |
import sys
|
| 18 |
|
| 19 |
+
# --- Setup: Clone repository and Change Working Directory ---
|
| 20 |
+
# This is the most robust way to ensure all relative paths work correctly.
|
| 21 |
|
| 22 |
# 1. Clone the repository with all its files
|
| 23 |
subprocess.run("git lfs install", shell=True, check=True)
|
| 24 |
+
repo_dir_name = "SeedVR2-3B"
|
| 25 |
+
if not os.path.exists(repo_dir_name):
|
| 26 |
+
print(f"Cloning {repo_dir_name} repository...")
|
| 27 |
+
subprocess.run(f"git clone https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
|
| 28 |
|
| 29 |
+
# 2. Change the current working directory to the repository's root
|
| 30 |
+
# CORREÇÃO PRINCIPAL: Isso resolve todos os problemas de caminho relativo.
|
| 31 |
+
os.chdir(repo_dir_name)
|
| 32 |
+
print(f"Changed working directory to: {os.getcwd()}")
|
|
|
|
| 33 |
|
| 34 |
# --- Main Application Code ---
|
| 35 |
+
# Now that we are inside the repo, all imports and file loads will work naturally.
|
| 36 |
|
| 37 |
import torch
|
| 38 |
import mediapy
|
|
|
|
| 51 |
from torchvision.transforms import Compose, Lambda, Normalize
|
| 52 |
from torchvision.io.video import read_video
|
| 53 |
|
| 54 |
+
# Imports from the repository (will now work directly)
|
| 55 |
from data.image.transforms.divisible_crop import DivisibleCrop
|
| 56 |
from data.image.transforms.na_resize import NaResize
|
| 57 |
from data.video.transforms.rearrange import Rearrange
|
|
|
|
| 63 |
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
|
| 64 |
from common.distributed.ops import sync_data
|
| 65 |
|
| 66 |
+
# Check for color_fix utility (using relative path)
|
| 67 |
+
if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
|
|
|
|
| 68 |
from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
|
| 69 |
use_colorfix = True
|
| 70 |
else:
|
|
|
|
| 77 |
os.environ["RANK"] = str(0)
|
| 78 |
os.environ["WORLD_SIZE"] = str(1)
|
| 79 |
|
| 80 |
+
# Use sys.executable to ensure we use the correct pip
|
| 81 |
python_executable = sys.executable
|
| 82 |
subprocess.run(
|
| 83 |
[python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
|
|
|
|
| 85 |
check=True
|
| 86 |
)
|
| 87 |
|
| 88 |
+
apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
|
| 89 |
if os.path.exists(apex_wheel_path):
|
|
|
|
| 90 |
subprocess.run([python_executable, "-m", "pip", "install", apex_wheel_path], check=True)
|
| 91 |
print("✅ Apex setup completed.")
|
| 92 |
|
|
|
|
| 97 |
init_sequence_parallel(sp_size)
|
| 98 |
|
| 99 |
def configure_runner(sp_size):
|
| 100 |
+
# Paths are now simple and relative to the repo root
|
| 101 |
+
config_path = 'configs_3b/main.yaml'
|
| 102 |
+
checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
|
| 103 |
|
| 104 |
+
config = load_config(config_path) # This will now work correctly
|
| 105 |
runner = VideoDiffusionInfer(config)
|
| 106 |
OmegaConf.set_readonly(runner.config, False)
|
| 107 |
|
|
|
|
| 157 |
def _extract_text_embeds():
|
| 158 |
positive_prompts_embeds = []
|
| 159 |
for _ in original_videos_local:
|
| 160 |
+
# Paths are now simple
|
| 161 |
+
text_pos_embeds = torch.load('pos_emb.pt')
|
| 162 |
+
text_neg_embeds = torch.load('neg_emb.pt')
|
| 163 |
positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
|
| 164 |
gc.collect()
|
| 165 |
torch.cuda.empty_cache()
|
|
|
|
| 276 |
# --- Gradio UI ---
|
| 277 |
|
| 278 |
with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
|
| 279 |
+
# Use an absolute path for the Gradio file source to be safe
|
| 280 |
+
logo_path = os.path.abspath("assets/seedvr_logo.png")
|
| 281 |
gr.HTML(f"""
|
| 282 |
+
|
| 283 |
+
<a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
|
| 284 |
+
🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.
|
|
|
|
|
|
|
| 285 |
""")
|
| 286 |
|
| 287 |
with gr.Row():
|