Aduc-sdr commited on
Commit
8ab9604
·
verified ·
1 Parent(s): 746b66d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -190
app.py CHANGED
@@ -15,35 +15,42 @@ import spaces
15
  import subprocess
16
  import os
17
  import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  import torch
19
- import mediapy
20
- from einops import rearrange
21
- from omegaconf import OmegaConf
22
- import datetime
23
- from tqdm import tqdm
24
- import gc
25
- from PIL import Image
26
- import gradio as gr
27
  from pathlib import Path
28
  from urllib.parse import urlparse
29
  from torch.hub import download_url_to_file, get_dir
30
  import shlex
31
- import uuid
32
- import mimetypes
33
- import torchvision.transforms as T
34
- from torchvision.transforms import Compose, Lambda, Normalize
35
- from torchvision.io.video import read_video
36
-
37
- # --- Lógica de Download de Arquivos (do script original) ---
38
 
 
39
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
40
- """Carrega um arquivo de um URL http, baixando modelos se necessário."""
41
  if model_dir is None:
42
  hub_dir = get_dir()
43
  model_dir = os.path.join(hub_dir, 'checkpoints')
44
-
45
  os.makedirs(model_dir, exist_ok=True)
46
-
47
  parts = urlparse(url)
48
  filename = os.path.basename(parts.path)
49
  if file_name is not None:
@@ -54,76 +61,62 @@ def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
54
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
55
  return cached_file
56
 
57
- ckpt_dir = Path('./ckpts')
58
- if not ckpt_dir.exists():
59
- ckpt_dir.mkdir()
60
-
61
  pretrain_model_url = {
62
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
63
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
64
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
65
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
66
- 'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp310-cp310-linux_x86_64.whl'
67
  }
68
 
69
- # Baixa os pesos e dependências se não existirem
70
- if not os.path.exists('./ckpts/seedvr2_ema_3b.pth'):
71
- load_file_from_url(url=pretrain_model_url['dit'], model_dir='./ckpts/', progress=True)
72
- if not os.path.exists('./ckpts/ema_vae.pth'):
73
- load_file_from_url(url=pretrain_model_url['vae'], model_dir='./ckpts/', progress=True)
74
- if not os.path.exists('./pos_emb.pt'):
75
- load_file_from_url(url=pretrain_model_url['pos_emb'], model_dir='./', progress=True)
76
- if not os.path.exists('./neg_emb.pt'):
77
- load_file_from_url(url=pretrain_model_url['neg_emb'], model_dir='./', progress=True)
78
- if not os.path.exists('./apex-0.1-cp310-cp310-linux_x86_64.whl'):
79
- load_file_from_url(url=pretrain_model_url['apex'], model_dir='./', progress=True)
80
-
81
- # Baixa os vídeos de exemplo
82
- torch.hub.download_url_to_file(
83
- 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4',
84
- '01.mp4')
85
- torch.hub.download_url_to_file(
86
- 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4',
87
- '02.mp4')
88
- torch.hub.download_url_to_file(
89
- 'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
90
- '03.mp4')
91
-
92
-
93
- # --- Configuração de Ambiente e Dependências ---
94
 
95
- os.environ["MASTER_ADDR"] = "127.0.0.1"
96
- os.environ["MASTER_PORT"] = "12355"
97
- os.environ["RANK"] = str(0)
98
- os.environ["WORLD_SIZE"] = str(1)
 
 
99
 
 
 
 
 
 
 
100
  python_executable = sys.executable
101
- subprocess.run(
102
- [python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
103
- env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
104
- check=True
105
- )
106
 
107
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
108
  if os.path.exists(apex_wheel_path):
109
  print("Instalando o Apex a partir do arquivo wheel...")
110
- subprocess.run(
111
- [python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", apex_wheel_path],
112
- check=True
113
- )
114
  print("✅ Configuração do Apex concluída.")
 
 
 
 
 
115
 
116
- # --- Código Principal da Aplicação ---
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  from data.image.transforms.divisible_crop import DivisibleCrop
119
  from data.image.transforms.na_resize import NaResize
120
  from data.video.transforms.rearrange import Rearrange
121
- if os.path.exists("./projects/video_diffusion_sr/color_fix.py"):
122
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
123
- use_colorfix=True
124
- else:
125
- use_colorfix = False
126
- print('Atenção!!!!!! A correção de cor não está disponível!')
127
  from common.config import load_config
128
  from common.distributed import init_torch
129
  from common.distributed.advanced import init_sequence_parallel
@@ -132,23 +125,31 @@ from common.partition import partition_by_size
132
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
133
  from common.distributed.ops import sync_data
134
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def configure_sequence_parallel(sp_size):
136
  if sp_size > 1:
137
  init_sequence_parallel(sp_size)
138
 
139
  def configure_runner(sp_size):
140
  config_path = 'configs_3b/main.yaml'
141
- checkpoint_path = 'ckpts/seedvr2_ema_3b.pth'
142
-
143
  config = load_config(config_path)
144
  runner = VideoDiffusionInfer(config)
145
  OmegaConf.set_readonly(runner.config, False)
146
-
147
  init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
148
  configure_sequence_parallel(sp_size)
149
- runner.configure_dit_model(device="cuda", checkpoint=checkpoint_path)
150
  runner.configure_vae_model()
151
-
152
  if hasattr(runner.vae, "set_memory_limit"):
153
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
154
  return runner
@@ -159,32 +160,21 @@ def generation_step(runner, text_embeds_dict, cond_latents):
159
 
160
  noises = [torch.randn_like(latent) for latent in cond_latents]
161
  aug_noises = [torch.randn_like(latent) for latent in cond_latents]
162
- print(f"Gerando com o formato de ruído: {noises[0].size()}.")
163
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
164
  noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
165
- cond_noise_scale = 0.1
166
-
167
  def _add_noise(x, aug_noise):
168
- t = torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale
169
  shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
170
  t = runner.timestep_transform(t, shape)
171
- print(f"Deslocamento de Timestep de {1000.0 * cond_noise_scale} para {t}.")
172
- x = runner.schedule.forward(x, aug_noise, t)
173
- return x
174
 
175
- conditions = [
176
- runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise))
177
- for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
178
- ]
179
 
180
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
181
- video_tensors = runner.inference(
182
- noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict
183
- )
184
 
185
- samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
186
- del video_tensors
187
- return samples
188
 
189
  @spaces.GPU
190
  def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
@@ -196,41 +186,19 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
196
  def _extract_text_embeds():
197
  positive_prompts_embeds = []
198
  for _ in original_videos_local:
199
- text_pos_embeds = torch.load('pos_emb.pt')
200
- text_neg_embeds = torch.load('neg_emb.pt')
201
- positive_prompts_embeds.append({"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]})
202
- gc.collect()
203
- torch.cuda.empty_cache()
204
  return positive_prompts_embeds
205
 
206
- def cut_videos(videos, sp_size):
207
- if videos.size(1) > 121:
208
- videos = videos[:, :121]
209
- t = videos.size(1)
210
- if t <= 4 * sp_size:
211
- padding_needed = 4 * sp_size - t + 1
212
- if padding_needed > 0:
213
- padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
214
- videos = torch.cat([videos, padding], dim=1)
215
- return videos
216
- if (t - 1) % (4 * sp_size) == 0:
217
- return videos
218
- else:
219
- padding_needed = 4 * sp_size - ((t - 1) % (4 * sp_size))
220
- padding = torch.cat([videos[:, -1].unsqueeze(1)] * padding_needed, dim=1)
221
- videos = torch.cat([videos, padding], dim=1)
222
- assert (videos.size(1) - 1) % (4 * sp_size) == 0
223
- return videos
224
-
225
  runner.config.diffusion.cfg.scale = cfg_scale
226
  runner.config.diffusion.cfg.rescale = cfg_rescale
227
  runner.config.diffusion.timesteps.sampling.steps = sample_steps
228
  runner.configure_diffusion()
229
-
230
- seed = int(seed) % (2**32)
231
- set_seed(seed, same_across_ranks=True)
232
- output_base_dir = "output"
233
- os.makedirs(output_base_dir, exist_ok=True)
234
 
235
  original_videos = [os.path.basename(video_path)]
236
  original_videos_local = partition_by_size(original_videos, batch_size)
@@ -239,92 +207,58 @@ def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.
239
  video_transform = Compose([
240
  NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
241
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
242
- DivisibleCrop((16, 16)),
243
- Normalize(0.5, 0.5),
244
- Rearrange("t c h w -> c t h w"),
245
  ])
246
 
247
  for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
248
- cond_latents = []
249
- for _ in videos:
250
- media_type, _ = mimetypes.guess_type(video_path)
251
- is_image = media_type and media_type.startswith("image")
252
- is_video = media_type and media_type.startswith("video")
 
 
 
 
 
253
 
254
- if is_video:
255
- video, _, _ = read_video(video_path, output_format="TCHW")
256
- video = video / 255.0
257
- if video.size(0) > 121:
258
- video = video[:121]
259
- print(f"Tamanho do vídeo lido: {video.size()}")
260
- output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.mp4")
261
- elif is_image:
262
- img = Image.open(video_path).convert("RGB")
263
- img_tensor = T.ToTensor()(img).unsqueeze(0)
264
- video = img_tensor
265
- print(f"Tamanho da imagem lida: {video.size()}")
266
- output_dir = os.path.join(output_base_dir, f"{uuid.uuid4()}.png")
267
- else:
268
- raise ValueError("Tipo de arquivo não suportado")
269
-
270
- cond_latents.append(video_transform(video.to(torch.device("cuda"))))
271
-
272
  ori_lengths = [v.size(1) for v in cond_latents]
273
- input_videos = cond_latents
274
- if is_video:
275
- cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
276
-
277
- print(f"Codificando vídeos: {[v.size() for v in cond_latents]}")
278
  cond_latents = runner.vae_encode(cond_latents)
279
 
280
- for i, emb in enumerate(text_embeds["texts_pos"]):
281
- text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
282
- for i, emb in enumerate(text_embeds["texts_neg"]):
283
- text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))
284
 
285
  samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
286
  del cond_latents
287
 
288
- for _, input_tensor, sample, ori_length in zip(videos, input_videos, samples, ori_lengths):
289
- if ori_length < sample.shape[0]:
290
- sample = sample[:ori_length]
291
-
292
- input_tensor = rearrange(input_tensor, "c t h w -> t c h w")
293
- if use_colorfix:
294
- sample = wavelet_reconstruction(sample.to("cpu"), input_tensor[:sample.size(0)].to("cpu"))
295
  else:
296
- sample = sample.to("cpu")
297
-
298
- sample = rearrange(sample, "t c h w -> t h w c")
299
- sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
300
- sample = sample.to(torch.uint8).numpy()
301
-
302
- if is_image:
303
  mediapy.write_image(output_dir, sample[0])
304
- else:
305
- mediapy.write_video(output_dir, sample, fps=fps_out)
306
-
307
- gc.collect()
308
- torch.cuda.empty_cache()
309
- if is_image:
310
- return output_dir, None, output_dir
311
- else:
312
- return None, output_dir, output_dir
313
 
314
- # --- UI do Gradio ---
 
315
 
316
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
317
- logo_path = "assets/seedvr_logo.png"
318
  gr.HTML(f"""
319
  <div style='text-align:center; margin-bottom: 10px;'>
320
- <img src='file/{logo_path}' style='height:40px;' alt='SeedVR logo'/>
321
  </div>
322
- <p><b>Demonstração oficial do Gradio</b> para <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'><b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
323
- 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.</p>
 
 
 
324
  """)
325
 
326
  with gr.Row():
327
- input_file = gr.File(label="Carregar imagem ou vídeo", type="filepath")
328
  with gr.Column():
329
  seed = gr.Number(label="Seed", value=666)
330
  fps = gr.Number(label="FPS de Saída (para vídeo)", value=24)
@@ -339,22 +273,25 @@ with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
339
 
340
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
341
 
342
- # Seção de Exemplos, que agora funcionará pois os vídeos são baixados
343
  gr.Examples(
344
  examples=[
345
- ["./01.mp4", 4, 24],
346
- ["./02.mp4", 4, 24],
347
- ["./03.mp4", 4, 24],
348
  ],
349
  inputs=[input_file, seed, fps]
350
  )
351
 
352
  gr.HTML("""
353
  <hr>
354
- <p>Se você achou o SeedVR útil, por favor ⭐ o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>:
355
- <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank"><img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars"></a></p>
 
 
 
356
  <h4>Aviso</h4>
357
- <p>Esta demonstração suporta até <b>720p e 121 frames para vídeos ou imagens 2k</b>. Para outros casos de uso, verifique o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
 
358
  <h4>Limitações</h4>
359
  <p>Pode falhar em degradações pesadas ou em clipes AIGC com pouco movimento, causando excesso de nitidez ou restauração inadequada.</p>
360
  """)
 
15
  import subprocess
16
  import os
17
  import sys
18
+
19
+ # --- ETAPA 1: Preparação do Ambiente ---
20
+ # Clonar o repositório para garantir que todas as pastas de código (data, common, etc.) existam.
21
+
22
+ repo_dir_name = "SeedVR2-3B"
23
+ if not os.path.exists(repo_dir_name):
24
+ print(f"Clonando o repositório {repo_dir_name} para obter todo o código-fonte...")
25
+ # Usamos --depth 1 para um clone mais rápido, já que não precisamos do histórico
26
+ subprocess.run(f"git clone --depth 1 https://huggingface.co/spaces/ByteDance-Seed/{repo_dir_name}", shell=True, check=True)
27
+
28
+ # --- ETAPA 2: Configuração dos Caminhos ---
29
+ # Mudar para o diretório do repositório e adicioná-lo ao path do Python.
30
+
31
+ # Mudar para o diretório do repositório. ESSENCIAL para caminhos de arquivos relativos.
32
+ os.chdir(repo_dir_name)
33
+ print(f"Diretório de trabalho alterado para: {os.getcwd()}")
34
+
35
+ # Adicionar o diretório ao sys.path. ESSENCIAL para as importações de módulos.
36
+ sys.path.insert(0, os.path.abspath('.'))
37
+ print(f"Diretório atual adicionado ao sys.path para importações.")
38
+
39
+ # --- ETAPA 3: Instalação de Dependências e Download de Modelos ---
40
+ # Agora que estamos no diretório correto, podemos prosseguir.
41
+
42
  import torch
 
 
 
 
 
 
 
 
43
  from pathlib import Path
44
  from urllib.parse import urlparse
45
  from torch.hub import download_url_to_file, get_dir
46
  import shlex
 
 
 
 
 
 
 
47
 
48
+ # Função de download do original
49
  def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
 
50
  if model_dir is None:
51
  hub_dir = get_dir()
52
  model_dir = os.path.join(hub_dir, 'checkpoints')
 
53
  os.makedirs(model_dir, exist_ok=True)
 
54
  parts = urlparse(url)
55
  filename = os.path.basename(parts.path)
56
  if file_name is not None:
 
61
  download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
62
  return cached_file
63
 
64
+ # URLs dos modelos
 
 
 
65
  pretrain_model_url = {
66
  'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
67
  'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
68
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
69
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
 
70
  }
71
 
72
+ # Criar diretório de checkpoints e baixar modelos
73
+ ckpt_dir = Path('./ckpts')
74
+ ckpt_dir.mkdir(exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ for key, url in pretrain_model_url.items():
77
+ filename = os.path.basename(url)
78
+ model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
79
+ target_path = os.path.join(model_dir, filename)
80
+ if not os.path.exists(target_path):
81
+ load_file_from_url(url=url, model_dir=model_dir, progress=True, file_name=filename)
82
 
83
+ # Baixar vídeos de exemplo
84
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
85
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
86
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
87
+
88
+ # Instalar dependências de forma robusta
89
  python_executable = sys.executable
90
+ subprocess.run([python_executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"], env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, check=True)
 
 
 
 
91
 
92
  apex_wheel_path = "apex-0.1-cp310-cp310-linux_x86_64.whl"
93
  if os.path.exists(apex_wheel_path):
94
  print("Instalando o Apex a partir do arquivo wheel...")
95
+ subprocess.run([python_executable, "-m", "pip", "install", "--force-reinstall", "--no-cache-dir", apex_wheel_path], check=True)
 
 
 
96
  print("✅ Configuração do Apex concluída.")
97
+ else:
98
+ print(f"AVISO: O arquivo wheel do Apex '{apex_wheel_path}' não foi encontrado no repositório clonado.")
99
+
100
+ # --- ETAPA 4: Execução do Código Principal da Aplicação ---
101
+ # Agora que o ambiente está perfeito, importamos e executamos o resto do script.
102
 
103
+ import mediapy
104
+ from einops import rearrange
105
+ from omegaconf import OmegaConf
106
+ import datetime
107
+ from tqdm import tqdm
108
+ import gc
109
+ from PIL import Image
110
+ import gradio as gr
111
+ import uuid
112
+ import mimetypes
113
+ import torchvision.transforms as T
114
+ from torchvision.transforms import Compose, Lambda, Normalize
115
+ from torchvision.io.video import read_video
116
 
117
  from data.image.transforms.divisible_crop import DivisibleCrop
118
  from data.image.transforms.na_resize import NaResize
119
  from data.video.transforms.rearrange import Rearrange
 
 
 
 
 
 
120
  from common.config import load_config
121
  from common.distributed import init_torch
122
  from common.distributed.advanced import init_sequence_parallel
 
125
  from projects.video_diffusion_sr.infer import VideoDiffusionInfer
126
  from common.distributed.ops import sync_data
127
 
128
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
129
+ os.environ["MASTER_PORT"] = "12355"
130
+ os.environ["RANK"] = str(0)
131
+ os.environ["WORLD_SIZE"] = str(1)
132
+
133
+ if os.path.exists("projects/video_diffusion_sr/color_fix.py"):
134
+ from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
135
+ use_colorfix = True
136
+ else:
137
+ use_colorfix = False
138
+ print('Atenção!!!!!! A correção de cor não está disponível!')
139
+
140
  def configure_sequence_parallel(sp_size):
141
  if sp_size > 1:
142
  init_sequence_parallel(sp_size)
143
 
144
  def configure_runner(sp_size):
145
  config_path = 'configs_3b/main.yaml'
 
 
146
  config = load_config(config_path)
147
  runner = VideoDiffusionInfer(config)
148
  OmegaConf.set_readonly(runner.config, False)
 
149
  init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
150
  configure_sequence_parallel(sp_size)
151
+ runner.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_3b.pth')
152
  runner.configure_vae_model()
 
153
  if hasattr(runner.vae, "set_memory_limit"):
154
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
155
  return runner
 
160
 
161
  noises = [torch.randn_like(latent) for latent in cond_latents]
162
  aug_noises = [torch.randn_like(latent) for latent in cond_latents]
 
163
  noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
164
  noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
165
+
 
166
  def _add_noise(x, aug_noise):
167
+ t = torch.tensor([1000.0], device=torch.device("cuda")) * 0.1
168
  shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
169
  t = runner.timestep_transform(t, shape)
170
+ return runner.schedule.forward(x, aug_noise, t)
 
 
171
 
172
+ conditions = [runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise)) for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)]
 
 
 
173
 
174
  with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
175
+ video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
 
 
176
 
177
+ return [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
 
 
178
 
179
  @spaces.GPU
180
  def generation_loop(video_path, seed=666, fps_out=24, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
 
186
  def _extract_text_embeds():
187
  positive_prompts_embeds = []
188
  for _ in original_videos_local:
189
+ positive_prompts_embeds.append({
190
+ "texts_pos": [torch.load('pos_emb.pt')],
191
+ "texts_neg": [torch.load('neg_emb.pt')]
192
+ })
193
+ gc.collect(); torch.cuda.empty_cache()
194
  return positive_prompts_embeds
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  runner.config.diffusion.cfg.scale = cfg_scale
197
  runner.config.diffusion.cfg.rescale = cfg_rescale
198
  runner.config.diffusion.timesteps.sampling.steps = sample_steps
199
  runner.configure_diffusion()
200
+ set_seed(int(seed) % (2**32), same_across_ranks=True)
201
+ os.makedirs("output", exist_ok=True)
 
 
 
202
 
203
  original_videos = [os.path.basename(video_path)]
204
  original_videos_local = partition_by_size(original_videos, batch_size)
 
207
  video_transform = Compose([
208
  NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
209
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
210
+ DivisibleCrop((16, 16)), Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w"),
 
 
211
  ])
212
 
213
  for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
214
+ media_type, _ = mimetypes.guess_type(video_path)
215
+ is_video = media_type and media_type.startswith("video")
216
+
217
+ if is_video:
218
+ video, _, _ = read_video(video_path, output_format="TCHW")
219
+ video = video[:121] / 255.0
220
+ output_dir = os.path.join("output", f"{uuid.uuid4()}.mp4")
221
+ else: # Assumimos que é uma imagem
222
+ video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
223
+ output_dir = os.path.join("output", f"{uuid.uuid4()}.png")
224
 
225
+ cond_latents = [video_transform(video.to("cuda"))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  ori_lengths = [v.size(1) for v in cond_latents]
 
 
 
 
 
227
  cond_latents = runner.vae_encode(cond_latents)
228
 
229
+ for key in ["texts_pos", "texts_neg"]:
230
+ for i, emb in enumerate(text_embeds[key]):
231
+ text_embeds[key][i] = emb.to("cuda")
 
232
 
233
  samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
234
  del cond_latents
235
 
236
+ for sample, ori_length in zip(samples, ori_lengths):
237
+ sample = sample[:ori_length].to("cpu")
238
+ sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
239
+
240
+ if is_video:
241
+ mediapy.write_video(output_dir, sample, fps=fps_out)
 
242
  else:
 
 
 
 
 
 
 
243
  mediapy.write_image(output_dir, sample[0])
 
 
 
 
 
 
 
 
 
244
 
245
+ gc.collect(); torch.cuda.empty_cache()
246
+ return (None, output_dir, output_dir) if is_video else (output_dir, None, output_dir)
247
 
248
  with gr.Blocks(title="SeedVR2: Restauração de Vídeo em Um Passo") as demo:
 
249
  gr.HTML(f"""
250
  <div style='text-align:center; margin-bottom: 10px;'>
251
+ <img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;' alt='SeedVR logo'/>
252
  </div>
253
+ <p><b>Demonstração oficial do Gradio</b> para
254
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
255
+ <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
256
+ 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
257
+ </p>
258
  """)
259
 
260
  with gr.Row():
261
+ input_file = gr.File(label="Carregar imagem ou vídeo")
262
  with gr.Column():
263
  seed = gr.Number(label="Seed", value=666)
264
  fps = gr.Number(label="FPS de Saída (para vídeo)", value=24)
 
273
 
274
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
275
 
 
276
  gr.Examples(
277
  examples=[
278
+ ["01.mp4", 4, 24],
279
+ ["02.mp4", 4, 24],
280
+ ["03.mp4", 4, 24],
281
  ],
282
  inputs=[input_file, seed, fps]
283
  )
284
 
285
  gr.HTML("""
286
  <hr>
287
+ <p>Se você achou o SeedVR útil, por favor ⭐ o
288
+ <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>:</p>
289
+ <a href="https://github.com/ByteDance-Seed/SeedVR" target="_blank">
290
+ <img src="https://img.shields.io/github/stars/ByteDance-Seed/SeedVR?style=social" alt="GitHub Stars">
291
+ </a>
292
  <h4>Aviso</h4>
293
+ <p>Esta demonstração suporta até <b>720p e 121 frames para vídeos ou imagens 2k</b>.
294
+ Para outros casos de uso, verifique o <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>repositório no GitHub</a>.</p>
295
  <h4>Limitações</h4>
296
  <p>Pode falhar em degradações pesadas ou em clipes AIGC com pouco movimento, causando excesso de nitidez ou restauração inadequada.</p>
297
  """)