Carlex22222 commited on
Commit
3d5bbef
·
verified ·
1 Parent(s): 07ebdd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -46
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py (VERSÃO FINAL COM DOWNLOAD DIRETO VIA TORCH.HUB)
2
 
3
  import gradio as gr
4
  import os
@@ -7,101 +7,98 @@ import shutil
7
  import subprocess
8
  import mimetypes
9
  from pathlib import Path
10
- from torch.hub import download_url_to_file # <-- MUDANÇA: Usamos a função de download do PyTorch
11
-
12
- # --- BLOCO DE CONFIGURAÇÃO E DOWNLOAD DE MODELO CORRIGIDO ---
13
- APP_DIR = "/app"
14
- SEEDVR_DIR = os.path.join(APP_DIR, "SeedVR")
15
- MODEL_CACHE_DIR = "/tmp/models"
16
- CKPTS_DIR = os.path.join(MODEL_CACHE_DIR, "ckpts")
17
 
 
 
 
18
  os.makedirs(CKPTS_DIR, exist_ok=True)
19
-
20
- # Dicionário com os links diretos para os arquivos do modelo
21
  files_to_download = {
22
  "seedvr2_ema_3b.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth",
23
  "ema_vae.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth",
24
  "pos_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt",
25
  "neg_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt",
26
  }
27
-
28
  print("Verificando e baixando modelos para /tmp/models/ckpts...")
29
  for filename, url in files_to_download.items():
30
  destination_path = os.path.join(CKPTS_DIR, filename)
31
  if not os.path.exists(destination_path):
32
- print(f"Baixando {filename}...")
33
- download_url_to_file(url, destination_path)
34
- print(f"{filename} baixado com sucesso.")
35
- else:
36
- print(f"{filename} já existe. Pulando o download.")
37
  print("Verificação de modelos concluída.")
38
  # --------------------------------------------------------------------
39
 
40
  def run_inference(video_path, seed, res_h, res_w):
41
- # ... (O resto do código permanece o mesmo) ...
42
  if video_path is None: raise gr.Error("Por favor, faça o upload de um arquivo.")
43
-
44
- job_id = str(uuid.uuid4())
45
- input_dir = os.path.join("/tmp", "temp_inputs", job_id)
46
- output_dir = os.path.join("/tmp", "temp_outputs", job_id)
47
  os.makedirs(input_dir, exist_ok=True); os.makedirs(output_dir, exist_ok=True)
48
-
49
  shutil.copy(video_path, input_dir)
50
-
51
  log_output = ""
52
-
53
  patched_script_path = os.path.join("/tmp", f"inference_patched_{job_id}.py")
54
  try:
55
  original_script_path = os.path.join(SEEDVR_DIR, "projects", "inference_seedvr2_3b.py")
56
  with open(original_script_path, 'r') as f: script_content = f.read()
57
-
58
  script_content = script_content.replace("'./ckpts/seedvr2_ema_3b.pth'", f"'{os.path.join(CKPTS_DIR, 'seedvr2_ema_3b.pth')}'")
59
  script_content = script_content.replace("runner.configure_vae_model()", f"runner.configure_vae_model(checkpoint_path='{os.path.join(CKPTS_DIR, 'ema_vae.pth')}')")
60
  script_content = script_content.replace("'pos_emb.pt'", f"'{os.path.join(CKPTS_DIR, 'pos_emb.pt')}'")
61
  script_content = script_content.replace("'neg_emb.pt'", f"'{os.path.join(CKPTS_DIR, 'neg_emb.pt')}'")
62
-
63
  with open(patched_script_path, 'w') as f: f.write(script_content)
64
-
65
- input_folder_relative = os.path.relpath(input_dir, SEEDVR_DIR)
66
- output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
67
  patched_script_relative_path = os.path.relpath(patched_script_path, SEEDVR_DIR)
68
-
69
  command = ["torchrun", "--nproc-per-node=4", patched_script_relative_path, "--video_path", input_folder_relative, "--output_dir", output_folder_relative, "--seed", str(seed), "--res_h", str(res_h), "--res_w", str(res_w)]
70
-
71
  env = os.environ.copy(); env["PYTHONUNBUFFERED"] = "1"
72
-
73
  log_output += f"Executando comando: {' '.join(command)}\n\n"
74
  yield None, None, log_output
75
-
76
  process = subprocess.Popen(command, cwd=SEEDVR_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding='utf-8', env=env)
77
-
78
  while True:
79
  output = process.stdout.readline()
80
  if output == '' and process.poll() is not None: break
81
- if output:
82
- log_output += output
83
- yield None, None, log_output
84
-
85
  if process.poll() != 0: raise gr.Error(f"A inferência falhou.")
86
-
87
  output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png', '.jpg', '.jpeg'))]
88
  if not output_files: raise gr.Error("Nenhum arquivo de saída foi encontrado.")
89
-
90
  result_path = os.path.join(output_dir, output_files[0])
91
-
92
  media_type, _ = mimetypes.guess_type(result_path)
93
  if media_type and media_type.startswith("image"): yield result_path, None, log_output
94
  else: yield None, result_path, log_output
95
-
96
  finally:
97
  shutil.rmtree(input_dir, ignore_errors=True)
98
  if os.path.exists(patched_script_path): os.remove(patched_script_path)
99
 
100
- # --- Interface Gráfica Gradio (sem alterações) ---
101
  with gr.Blocks(css="footer {display: none !important}") as demo:
102
  gr.Markdown("# 🚀 Interface de Inferência para SeedVR2")
103
- # ... (resto da UI)
104
- run_button.click(fn=run_inference, inputs=[input_media, seed, res_h, res_w], outputs=[output_image, output_video, log_box])
105
- gr.Examples(examples=[["./SeedVR/01.mp4", 666],["./SeedVR/02.mp4", 123],["./SeedVR/03.mp4", 42]], inputs=[input_media, seed])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  demo.queue(max_size=10).launch()
 
1
+ # app.py (VERSÃO FINAL E CORRIGIDA)
2
 
3
  import gradio as gr
4
  import os
 
7
  import subprocess
8
  import mimetypes
9
  from pathlib import Path
10
+ from torch.hub import download_url_to_file
 
 
 
 
 
 
11
 
12
+ # --- BLOCO DE CONFIGURAÇÃO E DOWNLOAD DE MODELO ---
13
+ APP_DIR = "/app"; SEEDVR_DIR = os.path.join(APP_DIR, "SeedVR")
14
+ MODEL_CACHE_DIR = "/tmp/models"; CKPTS_DIR = os.path.join(MODEL_CACHE_DIR, "ckpts")
15
  os.makedirs(CKPTS_DIR, exist_ok=True)
 
 
16
  files_to_download = {
17
  "seedvr2_ema_3b.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth",
18
  "ema_vae.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth",
19
  "pos_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt",
20
  "neg_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt",
21
  }
 
22
  print("Verificando e baixando modelos para /tmp/models/ckpts...")
23
  for filename, url in files_to_download.items():
24
  destination_path = os.path.join(CKPTS_DIR, filename)
25
  if not os.path.exists(destination_path):
26
+ print(f"Baixando {filename}..."); download_url_to_file(url, destination_path)
27
+ else: print(f"{filename} já existe.")
 
 
 
28
  print("Verificação de modelos concluída.")
29
  # --------------------------------------------------------------------
30
 
31
  def run_inference(video_path, seed, res_h, res_w):
 
32
  if video_path is None: raise gr.Error("Por favor, faça o upload de um arquivo.")
33
+ job_id = str(uuid.uuid4()); input_dir = os.path.join("/tmp", "temp_inputs", job_id); output_dir = os.path.join("/tmp", "temp_outputs", job_id)
 
 
 
34
  os.makedirs(input_dir, exist_ok=True); os.makedirs(output_dir, exist_ok=True)
 
35
  shutil.copy(video_path, input_dir)
 
36
  log_output = ""
 
37
  patched_script_path = os.path.join("/tmp", f"inference_patched_{job_id}.py")
38
  try:
39
  original_script_path = os.path.join(SEEDVR_DIR, "projects", "inference_seedvr2_3b.py")
40
  with open(original_script_path, 'r') as f: script_content = f.read()
 
41
  script_content = script_content.replace("'./ckpts/seedvr2_ema_3b.pth'", f"'{os.path.join(CKPTS_DIR, 'seedvr2_ema_3b.pth')}'")
42
  script_content = script_content.replace("runner.configure_vae_model()", f"runner.configure_vae_model(checkpoint_path='{os.path.join(CKPTS_DIR, 'ema_vae.pth')}')")
43
  script_content = script_content.replace("'pos_emb.pt'", f"'{os.path.join(CKPTS_DIR, 'pos_emb.pt')}'")
44
  script_content = script_content.replace("'neg_emb.pt'", f"'{os.path.join(CKPTS_DIR, 'neg_emb.pt')}'")
 
45
  with open(patched_script_path, 'w') as f: f.write(script_content)
46
+ input_folder_relative = os.path.relpath(input_dir, SEEDVR_DIR); output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
 
 
47
  patched_script_relative_path = os.path.relpath(patched_script_path, SEEDVR_DIR)
 
48
  command = ["torchrun", "--nproc-per-node=4", patched_script_relative_path, "--video_path", input_folder_relative, "--output_dir", output_folder_relative, "--seed", str(seed), "--res_h", str(res_h), "--res_w", str(res_w)]
 
49
  env = os.environ.copy(); env["PYTHONUNBUFFERED"] = "1"
 
50
  log_output += f"Executando comando: {' '.join(command)}\n\n"
51
  yield None, None, log_output
 
52
  process = subprocess.Popen(command, cwd=SEEDVR_DIR, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, encoding='utf-8', env=env)
 
53
  while True:
54
  output = process.stdout.readline()
55
  if output == '' and process.poll() is not None: break
56
+ if output: log_output += output; yield None, None, log_output
 
 
 
57
  if process.poll() != 0: raise gr.Error(f"A inferência falhou.")
 
58
  output_files = [f for f in os.listdir(output_dir) if f.endswith(('.mp4', '.png', '.jpg', '.jpeg'))]
59
  if not output_files: raise gr.Error("Nenhum arquivo de saída foi encontrado.")
 
60
  result_path = os.path.join(output_dir, output_files[0])
 
61
  media_type, _ = mimetypes.guess_type(result_path)
62
  if media_type and media_type.startswith("image"): yield result_path, None, log_output
63
  else: yield None, result_path, log_output
 
64
  finally:
65
  shutil.rmtree(input_dir, ignore_errors=True)
66
  if os.path.exists(patched_script_path): os.remove(patched_script_path)
67
 
68
+ # --- Interface Gráfica Gradio ---
69
  with gr.Blocks(css="footer {display: none !important}") as demo:
70
  gr.Markdown("# 🚀 Interface de Inferência para SeedVR2")
71
+ gr.Markdown("Faça o upload de um vídeo ou imagem, ajuste os parâmetros e clique em 'Executar'.")
72
+
73
+ with gr.Row():
74
+ with gr.Column(scale=1):
75
+ input_media = gr.Video(label="Upload de Vídeo ou Imagem")
76
+ seed = gr.Number(value=666, label="Seed")
77
+ with gr.Accordion("Configurações Avançadas", open=False):
78
+ res_h = gr.Number(value=720, label="Altura da Saída (res_h)")
79
+ res_w = gr.Number(value=1280, label="Largura da Saída (res_w)")
80
+ run_button = gr.Button("Executar", variant="primary")
81
+
82
+ with gr.Column(scale=2):
83
+ output_image = gr.Image(label="Saída de Imagem")
84
+ output_video = gr.Video(label="Saída de Vídeo")
85
+ log_box = gr.Textbox(label="Logs em Tempo Real", lines=15, autoscroll=True, interactive=False)
86
+
87
+ # !!! A CORREÇÃO FINAL ESTÁ AQUI !!!
88
+ # Estas duas chamadas foram movidas para DENTRO do bloco 'with gr.Blocks()'.
89
+ run_button.click(
90
+ fn=run_inference,
91
+ inputs=[input_media, seed, res_h, res_w],
92
+ outputs=[output_image, output_video, log_box]
93
+ )
94
+
95
+ gr.Examples(
96
+ examples=[
97
+ ["./SeedVR/01.mp4", 666],
98
+ ["./SeedVR/02.mp4", 123],
99
+ ["./SeedVR/03.mp4", 42],
100
+ ],
101
+ inputs=[input_media, seed]
102
+ )
103
 
104
  demo.queue(max_size=10).launch()