Aduc-sdr committed on
Commit
4f657cb
·
verified ·
1 Parent(s): 8959bc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -193
app.py CHANGED
@@ -15,35 +15,42 @@ import spaces
15
  import subprocess
16
  import os
17
  import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  import torch
19
- import mediapy
20
- from einops import rearrange
21
- from omegaconf import OmegaConf
22
- import datetime
23
- from tqdm import tqdm
24
- import gc
25
- from PIL import Image
26
- import gradio as gr
27
  from pathlib import Path
28
  from urllib.parse import urlparse
29
  from torch.hub import download_url_to_file, get_dir
30
  import shlex
31
- import uuid
32
- import mimetypes
33
- import torchvision.transforms as T
34
- from torchvision.transforms import Compose, Lambda, Normalize
35
- from torchvision.io.video import read_video
36
-
37
- # --- Lógica de Download de Arquivos (do script original) ---
38
 
 
39
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
40
- """Carrega um arquivo de um URL http, baixando modelos se necessário."""
41
  if model_dir is None:
42
  hub_dir = get_dir()
43
  model_dir = os.path.join(hub_dir, 'checkpoints')
44
-
45
  os.makedirs(model_dir, exist_ok=True)
46
-
47
  parts = urlparse(url)
48
  filename = os.path.basename(parts.path)
49
  if file_name is not None:
@@ -54,10 +61,7 @@ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
54
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
55
  return cached_file
56
 
57
- ckpt_dir = Path('./ckpts')
58
- if not ckpt_dir.exists():
59
- ckpt_dir.mkdir()
60
- https://github.com/ByteDance-Seed/SeedVR/blob/6b061f4670599178df207d85a52a8212dd37c541/pos_emb.pt
61
  pretrain_model_url = {
62
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/ema_vae.pth',
63
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/seedvr_ema_7b.pth',
@@ -66,65 +70,55 @@ pretrain_model_url = {
66
  'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
67
  }
68
 
69
- # Baixa os pesos e dependências se não existirem
70
- if not os.path.exists('./ckpts/seedvr2_ema_7b.pth'):
71
- load_file_from_url(url=pretrain_model_url['dit'], model_dir='./ckpts/', progress=True)
72
- if not os.path.exists('./ckpts/ema_vae.pth'):
73
- load_file_from_url(url=pretrain_model_url['vae'], model_dir='./ckpts/', progress=True)
74
- if not os.path.exists('./pos_emb.pt'):
75
- load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir='./', progress=True)
76
- if not os.path.exists('./neg_emb.pt'):
77
- load_file_from_url(url=pretrain_model_url['neg_emb'], model_dir='./', progress=True)
78
- if not os.path.exists('./apex-0.1-cp310-cp310-linux_x86_64.whl'):
79
- load_file_from_url(url=pretrain_model_url['apex'], model_dir='./', progress=True)
80
-
81
- # Baixa os vídeos de exemplo
82
- torch.hub.download_url_to_file(
83
- 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4',
84
- '01.mp4')
85
- torch.hub.download_url_to_file(
86
- 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4',
87
- '02.mp4')
88
- torch.hub.download_url_to_file(
89
- 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
90
- '03.mp4')
91
-
92
-
93
- # --- Configuração de Ambiente e Dependências ---
94
-
95
- os.environ["MASTER_ADDR"] = "127.0.0.1"
96
- os.environ["MASTER_PORT"] = "12355"
97
- os.environ["RANK"] = str(0)
98
- os.environ["WORLD_SIZE"] = str(1)
99
-
100
  python_executable = sys.executable
101
- subprocess.run(
102
- [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
103
- env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
104
- check=True
105
- )
106
 
107
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
108
  if os.path.exists(apex_wheel_path):
109
  print("Instalando o Apex a partir do arquivo wheel...")
110
- subprocess.run(
111
- [python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", apex_wheel_path],
112
- check=True
113
- )
114
  print("✅ Configuração do Apex concluída.")
 
 
115
 
116
- # --- Código Principal da Aplicação ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  from data.image.transforms.divisible_crop import DivisibleCrop
119
  from data.image.transforms.na_resize import NaResize
120
  from data.video.transforms.rearrange import Rearrange
121
- if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
122
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
123
- use_colorfix=True
124
- else:
125
-
126
- use_colorfix = False
127
- print('Atenção!!!!!! A correção de cor não está disponível!')
128
  from common.config import load_config
129
  from common.distributed import init_torch
130
  from common.distributed.advanced import init_sequence_parallel
@@ -133,23 +127,31 @@ from common.partition import partition_by_size
133
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
134
  from common.distributed.ops import sync_data
135
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def configure_sequence_parallel(sp_size):
137
  if sp_size > 1:
138
  init_sequence_parallel(sp_size)
139
 
140
  def configure_runner(sp_size):
141
  config_path = 'configs_7b/main.yaml'
142
- checkpoint_path = 'ckpts/seedvr2_ema_7b.pth'
143
-
144
  config = load_config(config_path)
145
  runner = VideoDiffusionInfer(config)
146
  OmegaConf.set_readonly(runner.config, False)
147
-
148
  init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
149
  configure_sequence_parallel(sp_size)
150
- runner.configure_dit_model(device="cuda", checkpoint=checkpoint_path)
151
  runner.configure_vae_model()
152
-
153
  if hasattr(runner.vae, "set_memory_limit"):
154
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
155
  return runner
@@ -160,34 +162,23 @@ def generation_step(runner, text_embeds_dict, cond_latents):
160
 
161
  noises = [torch.randn_like(latent) for latent in cond_latents]
162
  aug_noises = [torch.randn_like(latent) for latent in cond_latents]
163
- print(f"Gerando com o formato de ruído: {noises[0].size()}.")
164
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
165
  noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
166
- cond_noise_scale = 0.1
167
-
168
  def _add_noise(x, aug_noise):
169
- t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
170
  shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
171
  t = runner.timestep_transform(t, shape)
172
- print(f"Deslocamento de Timestep de {1000.0 * cond_noise_scale} para {t}.")
173
- x = runner.schedule.forward(x, aug_noise, t)
174
- return x
175
 
176
- conditions = [
177
- runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise))
178
- for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
179
- ]
180
 
181
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
182
- video_tensors = runner.inference(
183
- noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict
184
- )
185
 
186
- samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
187
- del video_tensors
188
- return samples
189
 
190
- @spaces.GPU
191
  def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
192
  if video_path is None:
193
  return None, None, None
@@ -197,41 +188,19 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
197
  def _extract_text_embeds():
198
  positive_prompts_embeds = []
199
  for _ in original_videos_local:
200
- text_pos_embeds = torch.load('pos_emb.pt')
201
- text_neg_embeds = torch.load('neg_emb.pt')
202
- positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
203
- gc.collect()
204
- torch.cuda.empty_cache()
205
  return positive_prompts_embeds
206
 
207
- def cut_videos(videos, sp_size):
208
- if videos.size(1) > 121:
209
- videos = videos[:, :121]
210
- t = videos.size(1)
211
- if t <= 4 * sp_size:
212
- padding_needed = 4 * sp_size - t + 1
213
- if padding_needed > 0:
214
- padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
215
- videos = torch.cat([videos, padding], dim=1)
216
- return videos
217
- if (t - 1) % (4 * sp_size) == 0:
218
- return videos
219
- else:
220
- padding_needed = 4 * sp_size - ((t - 1) % (4 * sp_size))
221
- padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
222
- videos = torch.cat([videos, padding], dim=1)
223
- assert (videos.size(1) - 1) % (4 * sp_size) == 0
224
- return videos
225
-
226
  runner.config.diffusion.cfg.scale = cfg_scale
227
  runner.config.diffusion.cfg.rescale = cfg_rescale
228
  runner.config.diffusion.timesteps.sampling.steps = sample_steps
229
  runner.configure_diffusion()
230
-
231
- seed = int(seed) % (2**32)
232
- set_seed(seed, same_across_ranks=True)
233
- output_base_dir = "output"
234
- os.makedirs(output_base_dir, exist_ok=True)
235
 
236
  original_videos = [os.path.basename(video_path)]
237
  original_videos_local = partition_by_size(original_videos, batch_size)
@@ -240,92 +209,58 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
240
  video_transform = Compose([
241
  NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
242
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
243
- DivisibleCrop((16, 16)),
244
- Normalize(0.5, 0.5),
245
- Rearrange("t c h w -> c t h w"),
246
  ])
247
 
248
  for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
249
- cond_latents = []
250
- for _ in videos:
251
- media_type, _ = mimetypes.guess_type(video_path)
252
- is_image = media_type and media_type.startswith("image")
253
- is_video = media_type and media_type.startswith("video")
 
 
 
 
 
254
 
255
- if is_video:
256
- video, _, _ = read_video(video_path, output_format="TCHW")
257
- video = video / 255.0
258
- if video.size(0) > 121:
259
- video = video[:121]
260
- print(f"Tamanho do vídeo lido: {video.size()}")
261
- output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
262
- elif is_image:
263
- img = Image.open(video_path).convert("RGB")
264
- img_tensor = T.ToTensor()(img).unsqueeze(0)
265
- video = img_tensor
266
- print(f"Tamanho da imagem lida: {video.size()}")
267
- output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
268
- else:
269
- raise ValueError("Tipo de arquivo não suportado")
270
-
271
- cond_latents.append(video_transform(video.to(torch.device("cuda"))))
272
-
273
  ori_lengths = [v.size(1) for v in cond_latents]
274
- input_videos = cond_latents
275
- if is_video:
276
- cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
277
-
278
- print(f"Codificando vídeos: {[v.size() for v in cond_latents]}")
279
  cond_latents = runner.vae_encode(cond_latents)
280
 
281
- for i, emb in enumerate(text_embeds["texts_pos"]):
282
- text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
283
- for i, emb in enumerate(text_embeds["texts_neg"]):
284
- text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))
285
 
286
  samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
287
  del cond_latents
288
 
289
- for _, input_tensor, sample, ori_length in zip(videos, input_videos, samples, ori_lengths):
290
- if ori_length < sample.shape[0]:
291
- sample = sample[:ori_length]
292
-
293
- input_tensor = rearrange(input_tensor, "c t h w -> t c h w")
294
- if use_colorfix:
295
- sample = wavelet_reconstruction(sample.to("cpu"), input_tensor[:sample.size(0)].to("cpu"))
296
  else:
297
- sample = sample.to("cpu")
298
-
299
- sample = rearrange(sample, "t c h w -> t h w c")
300
- sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
301
- sample = sample.to(torch.uint8).numpy()
302
-
303
- if is_image:
304
  mediapy.write_image(output_dir, sample[0])
305
- else:
306
- mediapy.write_video(output_dir, sample, fps=fps_out)
307
-
308
- gc.collect()
309
- torch.cuda.empty_cache()
310
- if is_image:
311
- return output_dir, None, output_dir
312
- else:
313
- return None, output_dir, output_dir
314
 
315
- # --- UI do Gradio ---
 
316
 
317
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
318
- logo_path = "assets/seedvr_logo.png"
319
  gr.HTML(f"""
320
  <div style='text-align:center; margin-bottom: 10px;'>
321
- <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
322
  </div>
323
- <p><b>Demonstração oficial do Gradio</b> para <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
324
- 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.</p>
 
 
 
325
  """)
326
 
327
  with gr.Row():
328
- input_file = gr.File(label="Carregar imagem ou vídeo", type="filepath")
329
  with gr.Column():
330
  seed = gr.Number(label="Seed", value=666)
331
  fps = gr.Number(label="FPS de Saída (para vídeo)", value=24)
@@ -340,22 +275,25 @@ with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
340
 
341
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
342
 
343
- # Seção de Exemplos, que agora funcionará pois os vídeos são baixados
344
  gr.Examples(
345
  examples=[
346
- ["./01.mp4", 4, 24],
347
- ["./02.mp4", 4, 24],
348
- ["./03.mp4", 4, 24],
349
  ],
350
  inputs=[input_file, seed, fps]
351
  )
352
 
353
  gr.HTML("""
354
  <hr>
355
- <p>Se você achou o SeedVR útil, por favor ⭐ o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>:
356
- <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
 
 
 
357
  <h4>Aviso</h4>
358
- <p>Esta demonstração suporta até <b>720p e 121 frames para vídeos ou imagens 2k</b>. Para outros casos de uso, verifique o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
 
359
  <h4>Limitações</h4>
360
  <p>Pode falhar em degradações pesadas ou em clipes AIGC com pouco movimento, causando excesso de nitidez ou restauração inadequada.</p>
361
  """)
 
15
  import subprocess
16
  import os
17
  import sys
18
+
19
+ # --- ETAPA 1: Preparação do Ambiente ---
20
+ # Clonar o repositório para garantir que todas as pastas de código (data, common, etc.) existam.
21
+
22
+ repo_dir_name = "SeedVR2-7B"
23
+ if not os.path.exists(repo_dir_name):
24
+ print(f"Clonando o repositório {repo_dir_name} para obter todo o código-fonte...")
25
+ # Usamos --depth 1 para um clone mais rápido, já que não precisamos do histórico
26
+ subprocess.run(["git", "clone", "--depth", "1", f"https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}"], check=True)  # argument list, no shell: same shallow clone, consistent with the pip invocations below
27
+
28
+ # --- ETAPA 2: Configuração dos Caminhos ---
29
+ # Mudar para o diretório do repositório e adicioná-lo ao path do Python.
30
+
31
+ # Mudar para o diretório do repositório. ESSENCIAL para caminhos de arquivos relativos.
32
+ os.chdir(repo_dir_name)
33
+ print(f"Diretório de trabalho alterado para: {os.getcwd()}")
34
+
35
+ # Adicionar o diretório ao sys.path. ESSENCIAL para as importações de módulos.
36
+ sys.path.insert(0, os.path.abspath('.'))
37
+ print(f"Diretório atual adicionado ao sys.path para importações.")
38
+
39
+ # --- ETAPA 3: Instalação de Dependências e Download de Modelos ---
40
+ # Agora que estamos no diretório correto, podemos prosseguir.
41
+
42
  import torch
 
 
 
 
 
 
 
 
43
  from pathlib import Path
44
  from urllib.parse import urlparse
45
  from torch.hub import download_url_to_file, get_dir
46
  import shlex
 
 
 
 
 
 
 
47
 
48
+ # Função de download do original
49
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
 
50
  if model_dir is None:
51
  hub_dir = get_dir()
52
  model_dir = os.path.join(hub_dir, 'checkpoints')
 
53
  os.makedirs(model_dir, exist_ok=True)
 
54
  parts = urlparse(url)
55
  filename = os.path.basename(parts.path)
56
  if file_name is not None:
 
61
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
62
  return cached_file
63
 
64
+ # URLs dos modelos
 
 
 
65
  pretrain_model_url = {
66
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/ema_vae.pth',
67
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/seedvr_ema_7b.pth',
 
70
  'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
71
  }
72
 
73
+ # Criar diretório de checkpoints e baixar modelos
74
+ ckpt_dir = Path('./ckpts')
75
+ ckpt_dir.mkdir(exist_ok=True)
76
+
77
+ for key, url in pretrain_model_url.items():
78
+ filename = os.path.basename(url)
79
+ model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
80
+ target_path = os.path.join(model_dir, filename)
81
+ if not os.path.exists(target_path):
82
+ load_file_from_url(url=url, model_dir=model_dir, progress=True, file_name=filename)
83
+
84
+ # Baixar vídeos de exemplo
85
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
86
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
87
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
88
+ torch.hub.download_url_to_file('https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl', 'apex-0.1-cp310-cp310-linux_x86_64.whl')  # dst is a required argument (omitting it raises TypeError); save under the file name that apex_wheel_path checks for
89
+
90
+ # Instalar dependências de forma robusta
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  python_executable = sys.executable
92
+ subprocess.run([python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"], env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, check=True)
 
 
 
 
93
 
94
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
95
  if os.path.exists(apex_wheel_path):
96
  print("Instalando o Apex a partir do arquivo wheel...")
97
+ subprocess.run([python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", apex_wheel_path], check=True)
 
 
 
98
  print("✅ Configuração do Apex concluída.")
99
+ else:
100
+ print(f"AVISO: O arquivo wheel do Apex '{apex_wheel_path}' não foi encontrado no repositório clonado.")
101
 
102
+ # --- ETAPA 4: Execução do Código Principal da Aplicação ---
103
+ # Agora que o ambiente está perfeito, importamos e executamos o resto do script.
104
+
105
+ import mediapy
106
+ from einops import rearrange
107
+ from omegaconf import OmegaConf
108
+ import datetime
109
+ from tqdm import tqdm
110
+ import gc
111
+ from PIL import Image
112
+ import gradio as gr
113
+ import uuid
114
+ import mimetypes
115
+ import torchvision.transforms as T
116
+ from torchvision.transforms import Compose, Lambda, Normalize
117
+ from torchvision.io.video import read_video
118
 
119
  from data.image.transforms.divisible_crop import DivisibleCrop
120
  from data.image.transforms.na_resize import NaResize
121
  from data.video.transforms.rearrange import Rearrange
 
 
 
 
 
 
 
122
  from common.config import load_config
123
  from common.distributed import init_torch
124
  from common.distributed.advanced import init_sequence_parallel
 
127
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
128
  from common.distributed.ops import sync_data
129
 
130
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
131
+ os.environ["MASTER_PORT"] = "12355"
132
+ os.environ["RANK"] = str(0)
133
+ os.environ["WORLD_SIZE"] = str(1)
134
+
135
+ if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
136
+ from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
137
+ use_colorfix = True
138
+ else:
139
+ use_colorfix = False
140
+ print('Atenção!!!!!! A correção de cor não está disponível!')
141
+
142
  def configure_sequence_parallel(sp_size):
143
  if sp_size > 1:
144
  init_sequence_parallel(sp_size)
145
 
146
  def configure_runner(sp_size):
147
  config_path = 'configs_7b/main.yaml'
 
 
148
  config = load_config(config_path)
149
  runner = VideoDiffusionInfer(config)
150
  OmegaConf.set_readonly(runner.config, False)
 
151
  init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
152
  configure_sequence_parallel(sp_size)
153
+ runner.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_7b.pth')
154
  runner.configure_vae_model()
 
155
  if hasattr(runner.vae, "set_memory_limit"):
156
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
157
  return runner
 
162
 
163
  noises = [torch.randn_like(latent) for latent in cond_latents]
164
  aug_noises = [torch.randn_like(latent) for latent in cond_latents]
 
165
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
166
  noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
167
+
 
168
  def _add_noise(x, aug_noise):
169
+ t = torch.tensor([1000.0], device=torch.device("cuda")) * 0.1
170
  shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
171
  t = runner.timestep_transform(t, shape)
172
+ return runner.schedule.forward(x, aug_noise, t)
 
 
173
 
174
+ conditions = [runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise)) for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)]
 
 
 
175
 
176
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
177
+ video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
178
+
179
+ return [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
180
 
 
 
 
181
 
 
182
  def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
183
  if video_path is None:
184
  return None, None, None
 
188
  def _extract_text_embeds():
189
  positive_prompts_embeds = []
190
  for _ in original_videos_local:
191
+ positive_prompts_embeds.append({
192
+ "texts_pos": [torch.load('pos_emb.pt')],
193
+ "texts_neg": [torch.load('neg_emb.pt')]
194
+ })
195
+ gc.collect(); torch.cuda.empty_cache()
196
  return positive_prompts_embeds
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  runner.config.diffusion.cfg.scale = cfg_scale
199
  runner.config.diffusion.cfg.rescale = cfg_rescale
200
  runner.config.diffusion.timesteps.sampling.steps = sample_steps
201
  runner.configure_diffusion()
202
+ set_seed(int(seed) % (2**32), same_across_ranks=True)
203
+ os.makedirs("output", exist_ok=True)
 
 
 
204
 
205
  original_videos = [os.path.basename(video_path)]
206
  original_videos_local = partition_by_size(original_videos, batch_size)
 
209
  video_transform = Compose([
210
  NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
211
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
212
+ DivisibleCrop((16, 16)), Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w"),
 
 
213
  ])
214
 
215
  for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
216
+ media_type, _ = mimetypes.guess_type(video_path)
217
+ is_video = media_type and media_type.startswith("video")
218
+
219
+ if is_video:
220
+ video, _, _ = read_video(video_path, output_format="TCHW")
221
+ video = video[:121] / 255.0
222
+ output_dir = os.path.join("output", f"{uuid.uuid4()}.mp4")
223
+ else: # Assumimos que é uma imagem
224
+ video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
225
+ output_dir = os.path.join("output", f"{uuid.uuid4()}.png")
226
 
227
+ cond_latents = [video_transform(video.to("cuda"))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  ori_lengths = [v.size(1) for v in cond_latents]
 
 
 
 
 
229
  cond_latents = runner.vae_encode(cond_latents)
230
 
231
+ for key in ["texts_pos", "texts_neg"]:
232
+ for i, emb in enumerate(text_embeds[key]):
233
+ text_embeds[key][i] = emb.to("cuda")
 
234
 
235
  samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
236
  del cond_latents
237
 
238
+ for sample, ori_length in zip(samples, ori_lengths):
239
+ sample = sample[:ori_length].to("cpu")
240
+ sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
241
+
242
+ if is_video:
243
+ mediapy.write_video(output_dir, sample, fps=fps_out)
 
244
  else:
 
 
 
 
 
 
 
245
  mediapy.write_image(output_dir, sample[0])
 
 
 
 
 
 
 
 
 
246
 
247
+ gc.collect(); torch.cuda.empty_cache()
248
+ return (None, output_dir, output_dir) if is_video else (output_dir, None, output_dir)
249
 
250
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
 
251
  gr.HTML(f"""
252
  <div style='text-align:center; margin-bottom: 10px;'>
253
+ <img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;' alt='SeedVR logo'/>
254
  </div>
255
+ <p><b>Demonstração oficial do Gradio</b> para
256
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
257
+ <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
258
+ 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
259
+ </p>
260
  """)
261
 
262
  with gr.Row():
263
+ input_file = gr.File(label="Carregar imagem ou vídeo")
264
  with gr.Column():
265
  seed = gr.Number(label="Seed", value=666)
266
  fps = gr.Number(label="FPS de Saída (para vídeo)", value=24)
 
275
 
276
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
277
 
 
278
  gr.Examples(
279
  examples=[
280
+ ["01.mp4", 4, 24],
281
+ ["02.mp4", 4, 24],
282
+ ["03.mp4", 4, 24],
283
  ],
284
  inputs=[input_file, seed, fps]
285
  )
286
 
287
  gr.HTML("""
288
  <hr>
289
+ <p>Se você achou o SeedVR útil, por favor ⭐ o
290
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>:</p>
291
+ <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank">
292
+ <img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars">
293
+ </a>
294
  <h4>Aviso</h4>
295
+ <p>Esta demonstração suporta até <b>720p e 121 frames para vídeos ou imagens 2k</b>.
296
+ Para outros casos de uso, verifique o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
297
  <h4>Limitações</h4>
298
  <p>Pode falhar em degradações pesadas ou em clipes AIGC com pouco movimento, causando excesso de nitidez ou restauração inadequada.</p>
299
  """)