Aduc-sdr committed on
Commit
0ea462c
·
verified ·
1 Parent(s): 9aefe9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -35
app.py CHANGED
@@ -2,7 +2,7 @@
2
  # //
3
  # // Licensed under the Apache License, Version 2.0 (the "License");
4
  # // you may not use this file except in compliance with the License.
5
- # // You may not obtain a copy of the License at
6
  # //
7
  # // http://www.apache.org/licenses/LICENSE-2.0
8
  # //
@@ -11,23 +11,12 @@
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
14
-
15
- import torch.distributed as dist
16
  import os
17
- import gc
18
- import logging
19
  import sys
20
- import subprocess
21
- from pathlib import Path
22
- from urllib.parse import urlparse
23
- from torch.hub import download_url_to_file
24
- import gradio as gr
25
- import mediapy
26
- from einops import rearrange
27
- import shutil
28
- from omegaconf import OmegaConf
29
 
30
- # --- ETAPA 1: Clonar o Repositório Oficial do GitHub ---
31
  repo_name = "SeedVR"
32
  if not os.path.exists(repo_name):
33
  print(f"Clonando o repositório {repo_name} do GitHub...")
@@ -37,14 +26,25 @@ if not os.path.exists(repo_name):
37
  os.chdir(repo_name)
38
  print(f"Diretório de trabalho alterado para: {os.getcwd()}")
39
 
40
- # Adicionar o diretório ao path do Python para que as importações funcionem
41
  sys.path.insert(0, os.path.abspath('.'))
42
  print(f"Diretório atual adicionado ao sys.path.")
43
 
44
- # --- ETAPA 3: Instalar Dependências Conforme as Instruções ---
45
  python_executable = sys.executable
46
- print("Instalando dependências do requirements.txt...")
47
- subprocess.run([python_executable, "-m", "pip", "install", "-r", "requirements.txt"], check=True)
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  print("Instalando flash-attn...")
50
  subprocess.run([python_executable, "-m", "pip", "install", "flash-attn==2.5.9.post1", "--no-build-isolation"], check=True)
@@ -53,7 +53,6 @@ from pathlib import Path
53
  from urllib.parse import urlparse
54
  from torch.hub import download_url_to_file, get_dir
55
 
56
- # Função auxiliar para downloads
57
  def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
58
  os.makedirs(model_dir, exist_ok=True)
59
  if not file_name:
@@ -65,7 +64,6 @@ def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
65
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
66
  return cached_file
67
 
68
- # Baixar e instalar Apex pré-compilado (crucial para o ambiente do Spaces)
69
  apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
70
  apex_wheel_path = load_file_from_url(url=apex_url)
71
  print("Instalando Apex a partir do wheel baixado...")
@@ -74,6 +72,8 @@ print("✅ Configuração do Apex concluída.")
74
 
75
  # --- ETAPA 4: Baixar os Modelos Pré-treinados ---
76
  print("Baixando modelos pré-treinados...")
 
 
77
  pretrain_model_url = {
78
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/ema_vae.pth',
79
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/seedvr_ema_7b.pth',
@@ -89,7 +89,6 @@ for key, url in pretrain_model_url.items():
89
 
90
 
91
  # --- ETAPA 5: Executar a Aplicação Principal ---
92
- import torch
93
  import mediapy
94
  from einops import rearrange
95
  from omegaconf import OmegaConf
@@ -109,9 +108,7 @@ from data.image.transforms.na_resize import NaResize
109
  from data.video.transforms.rearrange import Rearrange
110
  from common.config import load_config
111
  from common.distributed import init_torch
112
- from common.distributed.advanced import init_sequence_parallel
113
  from common.seed import set_seed
114
- from common.partition import partition_by_size
115
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
116
  from common.distributed.ops import sync_data
117
 
@@ -120,11 +117,9 @@ os.environ["MASTER_PORT"] = "12355"
120
  os.environ["RANK"] = str(0)
121
  os.environ["WORLD_SIZE"] = str(1)
122
 
123
- if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
 
124
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
125
- use_colorfix = True
126
- else:
127
- use_colorfix = False
128
 
129
  def configure_runner():
130
  config = load_config('configs_7b/main.yaml')
@@ -139,10 +134,9 @@ def configure_runner():
139
 
140
  def generation_step(runner, text_embeds_dict, cond_latents):
141
  def _move_to_cuda(x): return [i.to("cuda") for i in x]
142
- noises = [torch.randn_like(latent) for latent in cond_latents]
143
- aug_noises = [torch.randn_like(latent) for latent in cond_latents]
144
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
145
- noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
146
  def _add_noise(x, aug_noise):
147
  t = torch.tensor([100.0], device="cuda")
148
  shape = torch.tensor(x.shape[1:], device="cuda")[None]
@@ -157,13 +151,28 @@ def generation_step(runner, text_embeds_dict, cond_latents):
157
  def generation_loop(video_path, seed=666, fps_out=24):
158
  if video_path is None: return None, None, None
159
  runner = configure_runner()
160
- text_embeds = {"texts_pos": [torch.load('pos_emb.pt').to("cuda")], "texts_neg": [torch.load('neg_emb.pt').to("cuda")]}
 
 
 
 
161
  runner.configure_diffusion()
162
  set_seed(int(seed))
163
  os.makedirs("output", exist_ok=True)
164
- video_transform = Compose([NaResize(1024), DivisibleCrop(16), Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w")])
 
 
 
 
 
 
 
 
 
 
165
  media_type, _ = mimetypes.guess_type(video_path)
166
  is_video = media_type and media_type.startswith("video")
 
167
  if is_video:
168
  video, _, _ = read_video(video_path, output_format="TCHW")
169
  video = video[:121] / 255.0
@@ -171,12 +180,14 @@ def generation_loop(video_path, seed=666, fps_out=24):
171
  else:
172
  video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
173
  output_path = os.path.join("output", f"{uuid.uuid4()}.png")
174
- cond_latents = [video_transform(video.to("cuda"))]
 
175
  ori_length = cond_latents[0].size(2)
176
  cond_latents = runner.vae_encode(cond_latents)
177
  samples = generation_step(runner, text_embeds, cond_latents)
178
  sample = samples[0][:ori_length].cpu()
179
  sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
 
180
  if is_video:
181
  mediapy.write_video(output_path, sample, fps=fps_out)
182
  return None, output_path, output_path
@@ -185,7 +196,14 @@ def generation_loop(video_path, seed=666, fps_out=24):
185
  return output_path, None, output_path
186
 
187
  with gr.Blocks(title="SeedVR") as demo:
188
- gr.HTML(f"""<div style='text-align:center; margin-bottom: 10px;'><img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;'/></div>...""")
 
 
 
 
 
 
 
189
  with gr.Row():
190
  input_file = gr.File(label="Carregar Imagem ou Vídeo")
191
  with gr.Column():
 
2
  # //
3
  # // Licensed under the Apache License, Version 2.0 (the "License");
4
  # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
  # //
7
  # // http://www.apache.org/licenses/LICENSE-2.0
8
  # //
 
11
  # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # // See the License for the specific language governing permissions and
13
  # // limitations under the License.
14
+ import spaces
15
+ import subprocess
16
  import os
 
 
17
  import sys
 
 
 
 
 
 
 
 
 
18
 
19
+ # --- ETAPA 1: Clonar o Repositório do GitHub ---
20
  repo_name = "SeedVR"
21
  if not os.path.exists(repo_name):
22
  print(f"Clonando o repositório {repo_name} do GitHub...")
 
26
  os.chdir(repo_name)
27
  print(f"Diretório de trabalho alterado para: {os.getcwd()}")
28
 
 
29
  sys.path.insert(0, os.path.abspath('.'))
30
  print(f"Diretório atual adicionado ao sys.path.")
31
 
32
+ # --- ETAPA 3: Instalar Dependências Corretamente ---
33
  python_executable = sys.executable
34
+
35
+ # CORREÇÃO: Forçar uma versão do NumPy < 2.0 para evitar conflitos de compatibilidade.
36
+ print("Instalando NumPy compatível...")
37
+ subprocess.run([python_executable, "-m", "pip", "install", "numpy<2.0"], check=True)
38
+
39
+ # Filtrar requirements.txt para evitar conflitos com torch/torchvision pré-instalados
40
+ print("Filtrando requirements.txt...")
41
+ with open("requirements.txt", "r") as f_in, open("filtered_requirements.txt", "w") as f_out:
42
+ for line in f_in:
43
+ if not line.strip().startswith(('torch', 'torchvision')):
44
+ f_out.write(line)
45
+
46
+ print("Instalando dependências filtradas...")
47
+ subprocess.run([python_executable, "-m", "pip", "install", "-r", "filtered_requirements.txt"], check=True)
48
 
49
  print("Instalando flash-attn...")
50
  subprocess.run([python_executable, "-m", "pip", "install", "flash-attn==2.5.9.post1", "--no-build-isolation"], check=True)
 
53
  from urllib.parse import urlparse
54
  from torch.hub import download_url_to_file, get_dir
55
 
 
56
  def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
57
  os.makedirs(model_dir, exist_ok=True)
58
  if not file_name:
 
64
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
65
  return cached_file
66
 
 
67
  apex_url = 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
68
  apex_wheel_path = load_file_from_url(url=apex_url)
69
  print("Instalando Apex a partir do wheel baixado...")
 
72
 
73
  # --- ETAPA 4: Baixar os Modelos Pré-treinados ---
74
  print("Baixando modelos pré-treinados...")
75
+ import torch
76
+
77
  pretrain_model_url = {
78
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/ema_vae.pth',
79
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/seedvr_ema_7b.pth',
 
89
 
90
 
91
  # --- ETAPA 5: Executar a Aplicação Principal ---
 
92
  import mediapy
93
  from einops import rearrange
94
  from omegaconf import OmegaConf
 
108
  from data.video.transforms.rearrange import Rearrange
109
  from common.config import load_config
110
  from common.distributed import init_torch
 
111
  from common.seed import set_seed
 
112
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
113
  from common.distributed.ops import sync_data
114
 
 
117
  os.environ["RANK"] = str(0)
118
  os.environ["WORLD_SIZE"] = str(1)
119
 
120
+ use_colorfix = os.path.exists("projects/video_diffusion_sr/color_fix.py")
121
+ if use_colorfix:
122
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
 
 
 
123
 
124
  def configure_runner():
125
  config = load_config('configs_7b/main.yaml')
 
134
 
135
  def generation_step(runner, text_embeds_dict, cond_latents):
136
  def _move_to_cuda(x): return [i.to("cuda") for i in x]
137
+ noises, aug_noises = [torch.randn_like(l) for l in cond_latents], [torch.randn_like(l) for l in cond_latents]
 
138
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
139
+ noises, aug_noises, cond_latents = map(_move_to_cuda, (noises, aug_noises, cond_latents))
140
  def _add_noise(x, aug_noise):
141
  t = torch.tensor([100.0], device="cuda")
142
  shape = torch.tensor(x.shape[1:], device="cuda")[None]
 
151
  def generation_loop(video_path, seed=666, fps_out=24):
152
  if video_path is None: return None, None, None
153
  runner = configure_runner()
154
+ # Adicionado `weights_only=True` para segurança e para suprimir o aviso
155
+ text_embeds = {
156
+ "texts_pos": [torch.load('pos_emb.pt', weights_only=True).to("cuda")],
157
+ "texts_neg": [torch.load('neg_emb.pt', weights_only=True).to("cuda")]
158
+ }
159
  runner.configure_diffusion()
160
  set_seed(int(seed))
161
  os.makedirs("output", exist_ok=True)
162
+
163
+ # CORREÇÃO: Fornecer os argumentos que faltam para NaResize.
164
+ res_h, res_w = 1280, 720
165
+ transform = Compose([
166
+ NaResize(resolution=(res_h * res_w)**0.5, mode="area", downsample_only=False),
167
+ Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
168
+ DivisibleCrop((16, 16)),
169
+ Normalize(0.5, 0.5),
170
+ Rearrange("t c h w -> c t h w")
171
+ ])
172
+
173
  media_type, _ = mimetypes.guess_type(video_path)
174
  is_video = media_type and media_type.startswith("video")
175
+
176
  if is_video:
177
  video, _, _ = read_video(video_path, output_format="TCHW")
178
  video = video[:121] / 255.0
 
180
  else:
181
  video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
182
  output_path = os.path.join("output", f"{uuid.uuid4()}.png")
183
+
184
+ cond_latents = [transform(video.to("cuda"))]
185
  ori_length = cond_latents[0].size(2)
186
  cond_latents = runner.vae_encode(cond_latents)
187
  samples = generation_step(runner, text_embeds, cond_latents)
188
  sample = samples[0][:ori_length].cpu()
189
  sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
190
+
191
  if is_video:
192
  mediapy.write_video(output_path, sample, fps=fps_out)
193
  return None, output_path, output_path
 
196
  return output_path, None, output_path
197
 
198
  with gr.Blocks(title="SeedVR") as demo:
199
+ gr.HTML(f"""
200
+
201
+ <p><b>Demonstração oficial do Gradio</b> para
202
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
203
+ <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
204
+ 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
205
+ </p>
206
+ """)
207
  with gr.Row():
208
  input_file = gr.File(label="Carregar Imagem ou Vídeo")
209
  with gr.Column():