Aduc-sdr commited on
Commit
4891df0
·
verified ·
1 Parent(s): 0ea462c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -30
app.py CHANGED
@@ -32,12 +32,10 @@ print(f"Diretório atual adicionado ao sys.path.")
32
  # --- ETAPA 3: Instalar Dependências Corretamente ---
33
  python_executable = sys.executable
34
 
35
- # CORREÇÃO: Forçar uma versão do NumPy < 2.0 para evitar conflitos de compatibilidade.
36
  print("Instalando NumPy compatível...")
37
  subprocess.run([python_executable, "-m", "pip", "install", "numpy<2.0"], check=True)
38
 
39
- # Filtrar requirements.txt para evitar conflitos com torch/torchvision pré-instalados
40
- print("Filtrando requirements.txt...")
41
  with open("requirements.txt", "r") as f_in, open("filtered_requirements.txt", "w") as f_out:
42
  for line in f_in:
43
  if not line.strip().startswith(('torch', 'torchvision')):
@@ -52,6 +50,7 @@ subprocess.run([python_executable, "-m", "pip", "install", "flash-attn==2.5.9.po
52
  from pathlib import Path
53
  from urllib.parse import urlparse
54
  from torch.hub import download_url_to_file, get_dir
 
55
 
56
  def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
57
  os.makedirs(model_dir, exist_ok=True)
@@ -72,14 +71,11 @@ print("✅ Configuração do Apex concluída.")
72
 
73
  # --- ETAPA 4: Baixar os Modelos Pré-treinados ---
74
  print("Baixando modelos pré-treinados...")
75
- import torch
76
-
77
  pretrain_model_url = {
78
- 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/ema_vae.pth',
79
- 'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR-7B/resolve/main/seedvr_ema_7b.pth',
80
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
81
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
82
- #'apex': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/apex-0.1-cp39-cp39-linux_x86_64.whl'
83
  }
84
 
85
  Path('./ckpts').mkdir(exist_ok=True)
@@ -87,8 +83,12 @@ for key, url in pretrain_model_url.items():
87
  model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
88
  load_file_from_url(url=url, model_dir=model_dir)
89
 
 
 
 
90
 
91
- # --- ETAPA 5: Executar a Aplicação Principal ---
 
92
  import mediapy
93
  from einops import rearrange
94
  from omegaconf import OmegaConf
@@ -122,16 +122,24 @@ if use_colorfix:
122
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
123
 
124
  def configure_runner():
125
- config = load_config('configs_7b/main.yaml')
126
  runner = VideoDiffusionInfer(config)
127
  OmegaConf.set_readonly(runner.config, False)
 
128
  init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
129
- runner.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_7b.pth')
130
  runner.configure_vae_model()
131
  if hasattr(runner.vae, "set_memory_limit"):
132
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
133
  return runner
134
 
 
 
 
 
 
 
 
135
  def generation_step(runner, text_embeds_dict, cond_latents):
136
  def _move_to_cuda(x): return [i.to("cuda") for i in x]
137
  noises, aug_noises = [torch.randn_like(l) for l in cond_latents], [torch.randn_like(l) for l in cond_latents]
@@ -147,11 +155,27 @@ def generation_step(runner, text_embeds_dict, cond_latents):
147
  video_tensors = runner.inference(noises=noises, conditions=conditions, **text_embeds_dict)
148
  return [rearrange(v, "c t h w -> t c h w") for v in video_tensors]
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  @spaces.GPU
151
  def generation_loop(video_path, seed=666, fps_out=24):
152
  if video_path is None: return None, None, None
153
- runner = configure_runner()
154
- # Adicionado `weights_only=True` para segurança e para suprimir o aviso
 
155
  text_embeds = {
156
  "texts_pos": [torch.load('pos_emb.pt', weights_only=True).to("cuda")],
157
  "texts_neg": [torch.load('neg_emb.pt', weights_only=True).to("cuda")]
@@ -160,31 +184,36 @@ def generation_loop(video_path, seed=666, fps_out=24):
160
  set_seed(int(seed))
161
  os.makedirs("output", exist_ok=True)
162
 
163
- # CORREÇÃO: Fornecer os argumentos que faltam para NaResize.
164
  res_h, res_w = 1280, 720
165
  transform = Compose([
166
  NaResize(resolution=(res_h * res_w)**0.5, mode="area", downsample_only=False),
167
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
168
- DivisibleCrop((16, 16)),
169
- Normalize(0.5, 0.5),
170
- Rearrange("t c h w -> c t h w")
171
  ])
172
 
173
  media_type, _ = mimetypes.guess_type(video_path)
174
  is_video = media_type and media_type.startswith("video")
175
 
176
  if is_video:
177
- video, _, _ = read_video(video_path, output_format="TCHW")
178
- video = video[:121] / 255.0
179
  output_path = os.path.join("output", f"{uuid.uuid4()}.mp4")
180
  else:
181
  video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
182
  output_path = os.path.join("output", f"{uuid.uuid4()}.png")
183
 
184
- cond_latents = [transform(video.to("cuda"))]
185
- ori_length = cond_latents[0].size(2)
 
 
 
 
 
 
 
186
  cond_latents = runner.vae_encode(cond_latents)
187
  samples = generation_step(runner, text_embeds, cond_latents)
 
188
  sample = samples[0][:ori_length].cpu()
189
  sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
190
 
@@ -196,14 +225,7 @@ def generation_loop(video_path, seed=666, fps_out=24):
196
  return output_path, None, output_path
197
 
198
  with gr.Blocks(title="SeedVR") as demo:
199
- gr.HTML(f"""
200
-
201
- <p><b>Demonstração oficial do Gradio</b> para
202
- <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
203
- <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
204
- 🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
205
- </p>
206
- """)
207
  with gr.Row():
208
  input_file = gr.File(label="Carregar Imagem ou Vídeo")
209
  with gr.Column():
@@ -214,6 +236,7 @@ with gr.Blocks(title="SeedVR") as demo:
214
  output_video = gr.Video(label="Vídeo de Saída")
215
  download_link = gr.File(label="Baixar Resultado")
216
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
217
-
 
218
 
219
  demo.queue().launch(share=True)
 
32
  # --- ETAPA 3: Instalar Dependências Corretamente ---
33
  python_executable = sys.executable
34
 
 
35
  print("Instalando NumPy compatível...")
36
  subprocess.run([python_executable, "-m", "pip", "install", "numpy<2.0"], check=True)
37
 
38
+ print("Filtrando requirements.txt para evitar conflitos de versão...")
 
39
  with open("requirements.txt", "r") as f_in, open("filtered_requirements.txt", "w") as f_out:
40
  for line in f_in:
41
  if not line.strip().startswith(('torch', 'torchvision')):
 
50
  from pathlib import Path
51
  from urllib.parse import urlparse
52
  from torch.hub import download_url_to_file, get_dir
53
+ import torch
54
 
55
  def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
56
  os.makedirs(model_dir, exist_ok=True)
 
71
 
72
  # --- ETAPA 4: Baixar os Modelos Pré-treinados ---
73
  print("Baixando modelos pré-treinados...")
 
 
74
  pretrain_model_url = {
75
+ 'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
76
+ 'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
77
  'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
78
  'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
 
79
  }
80
 
81
  Path('./ckpts').mkdir(exist_ok=True)
 
83
  model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
84
  load_file_from_url(url=url, model_dir=model_dir)
85
 
86
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4', '01.mp4')
87
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4', '02.mp4')
88
+ torch.hub.download_url_to_file('https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4', '03.mp4')
89
 
90
+ # --- ETAPA 5: Inicialização Global do Modelo (FEITA APENAS UMA VEZ) ---
91
+ print("Inicializando o modelo e o ambiente distribuído (uma única vez)...")
92
  import mediapy
93
  from einops import rearrange
94
  from omegaconf import OmegaConf
 
122
  from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
123
 
124
  def configure_runner():
125
+ config = load_config('configs_3b/main.yaml')
126
  runner = VideoDiffusionInfer(config)
127
  OmegaConf.set_readonly(runner.config, False)
128
+ # A chamada de inicialização crítica é feita aqui
129
  init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
130
+ runner.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_3b.pth')
131
  runner.configure_vae_model()
132
  if hasattr(runner.vae, "set_memory_limit"):
133
  runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
134
  return runner
135
 
136
+ # Criamos o runner globalmente, UMA ÚNICA VEZ
137
+ GLOBAL_RUNNER = configure_runner()
138
+ print("✅ Setup completo. Aplicação pronta para receber requisições.")
139
+
140
+
141
+ # --- ETAPA 6: Funções de Inferência e UI do Gradio ---
142
+
143
  def generation_step(runner, text_embeds_dict, cond_latents):
144
  def _move_to_cuda(x): return [i.to("cuda") for i in x]
145
  noises, aug_noises = [torch.randn_like(l) for l in cond_latents], [torch.randn_like(l) for l in cond_latents]
 
155
  video_tensors = runner.inference(noises=noises, conditions=conditions, **text_embeds_dict)
156
  return [rearrange(v, "c t h w -> t c h w") for v in video_tensors]
157
 
158
+ def cut_videos(videos, sp_size=1):
159
+ t = videos.size(1)
160
+ if t > 121:
161
+ videos = videos[:, :121]
162
+ t = 121
163
+ if (t - 1) % (4 * sp_size) == 0:
164
+ return videos
165
+ else:
166
+ padding_needed = 4 * sp_size - ((t - 1) % (4 * sp_size))
167
+ last_frame = videos[:, -1].unsqueeze(1)
168
+ padding = last_frame.repeat(1, padding_needed, 1, 1)
169
+ videos = torch.cat([videos, padding], dim=1)
170
+ assert (videos.size(1) - 1) % (4 * sp_size) == 0
171
+ return videos
172
+
173
  @spaces.GPU
174
  def generation_loop(video_path, seed=666, fps_out=24):
175
  if video_path is None: return None, None, None
176
+ # CORREÇÃO: Usamos o runner global em vez de criar um novo
177
+ runner = GLOBAL_RUNNER
178
+
179
  text_embeds = {
180
  "texts_pos": [torch.load('pos_emb.pt', weights_only=True).to("cuda")],
181
  "texts_neg": [torch.load('neg_emb.pt', weights_only=True).to("cuda")]
 
184
  set_seed(int(seed))
185
  os.makedirs("output", exist_ok=True)
186
 
 
187
  res_h, res_w = 1280, 720
188
  transform = Compose([
189
  NaResize(resolution=(res_h * res_w)**0.5, mode="area", downsample_only=False),
190
  Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
191
+ DivisibleCrop((16, 16)), Normalize(0.5, 0.5), Rearrange("t c h w -> c t h w")
 
 
192
  ])
193
 
194
  media_type, _ = mimetypes.guess_type(video_path)
195
  is_video = media_type and media_type.startswith("video")
196
 
197
  if is_video:
198
+ video, _, _ = read_video(video_path, output_format="TCHW", pts_unit="sec")
199
+ video = video / 255.0
200
  output_path = os.path.join("output", f"{uuid.uuid4()}.mp4")
201
  else:
202
  video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
203
  output_path = os.path.join("output", f"{uuid.uuid4()}.png")
204
 
205
+ transformed_video = transform(video.to("cuda"))
206
+ ori_length = transformed_video.size(1)
207
+
208
+ if is_video:
209
+ padded_video = cut_videos(transformed_video)
210
+ cond_latents = [padded_video]
211
+ else:
212
+ cond_latents = [transformed_video]
213
+
214
  cond_latents = runner.vae_encode(cond_latents)
215
  samples = generation_step(runner, text_embeds, cond_latents)
216
+
217
  sample = samples[0][:ori_length].cpu()
218
  sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
219
 
 
225
  return output_path, None, output_path
226
 
227
  with gr.Blocks(title="SeedVR") as demo:
228
+ gr.HTML(f"""<div style='text-align:center; margin-bottom: 10px;'><img src='file/{os.path.abspath("assets/seedvr_logo.png")}' style='height:40px;'/></div>...""")
 
 
 
 
 
 
 
229
  with gr.Row():
230
  input_file = gr.File(label="Carregar Imagem ou Vídeo")
231
  with gr.Column():
 
236
  output_video = gr.Video(label="Vídeo de Saída")
237
  download_link = gr.File(label="Baixar Resultado")
238
  run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
239
+ gr.Examples(examples=[["01.mp4", 42, 24], ["02.mp4", 42, 24], ["03.mp4", 42, 24]], inputs=[input_file, seed, fps])
240
+ gr.HTML("""<hr>...""")
241
 
242
  demo.queue().launch(share=True)