Spaces:
Running
Running
File size: 5,877 Bytes
8ee2f06 2667dea 6bfaaa7 2667dea d4386a4 eecde3e 2588a29 8ee2f06 2588a29 3d5bbef 2588a29 8ee2f06 07ebdd7 8ee2f06 07ebdd7 3d5bbef 07ebdd7 2588a29 8ee2f06 2588a29 07ebdd7 3d5bbef 07ebdd7 2588a29 67137e5 2588a29 67137e5 07ebdd7 5da6833 07ebdd7 3d5bbef 67137e5 07ebdd7 67137e5 07ebdd7 2588a29 67137e5 3d5bbef 5da6833 2588a29 5da6833 2588a29 07ebdd7 2588a29 07ebdd7 2588a29 07ebdd7 3d5bbef 8ee2f06 3d5bbef 07ebdd7 c26a69c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
# app.py (VERSÃO FINAL COM LAZY IMPORT)
import gradio as gr
import os
import uuid
import shutil
import subprocess
import mimetypes
from pathlib import Path
# --- Configuration and checkpoint download (runs once, at import time) ---
# Container layout: the SeedVR repository is expected under /app/SeedVR and the
# SeedVR2-3B checkpoints are cached under /tmp/models/ckpts.
APP_DIR = "/app"
SEEDVR_DIR = os.path.join(APP_DIR, "SeedVR")
MODEL_CACHE_DIR = "/tmp/models"
CKPTS_DIR = os.path.join(MODEL_CACHE_DIR, "ckpts")
os.makedirs(CKPTS_DIR, exist_ok=True)

# Checkpoint file name -> download URL on the Hugging Face Hub.
files_to_download = {
    "seedvr2_ema_3b.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth",
    "ema_vae.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth",
    "pos_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt",
    "neg_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt",
}

print(f"Verificando e baixando modelos para {CKPTS_DIR}...")
# Work out which checkpoints are missing *before* touching torch, so the heavy
# torch import only happens when at least one file actually has to be
# downloaded (a genuinely lazy import).
_missing_files = {
    name
    for name in files_to_download
    if not os.path.exists(os.path.join(CKPTS_DIR, name))
}
if _missing_files:
    from torch.hub import download_url_to_file

for filename, url in files_to_download.items():
    destination_path = os.path.join(CKPTS_DIR, filename)
    if filename in _missing_files:
        print(f"Baixando {filename}...")
        download_url_to_file(url, destination_path)
    else:
        print(f"{filename} já existe.")
print("Verificação de modelos concluída.")
# --------------------------------------------------------------------
# Inference: patch the upstream SeedVR2-3B script so it loads checkpoints from
# CKPTS_DIR, run it under torchrun and stream its console output to the UI.

# File extensions accepted as an inference result (compared case-insensitively).
_RESULT_EXTENSIONS = (".mp4", ".png", ".jpg", ".jpeg")


def _write_patched_script(patched_script_path):
    """Write a copy of the upstream inference script with checkpoint paths rewritten.

    Reads ``SEEDVR_DIR/projects/inference_seedvr2_3b.py``, replaces its literal
    checkpoint references with absolute paths inside ``CKPTS_DIR`` and writes the
    result to *patched_script_path*.

    Returns:
        The search patterns that did not occur in the upstream script (their
        replacement was a no-op), so the caller can surface them in the log.
    """
    original_script_path = os.path.join(SEEDVR_DIR, "projects", "inference_seedvr2_3b.py")
    # Applied in this order; each pattern is matched literally (quotes included).
    replacements = (
        ("'./ckpts/seedvr2_ema_3b.pth'", f"'{os.path.join(CKPTS_DIR, 'seedvr2_ema_3b.pth')}'"),
        (
            "runner.configure_vae_model()",
            f"runner.configure_vae_model(checkpoint_path='{os.path.join(CKPTS_DIR, 'ema_vae.pth')}')",
        ),
        ("'pos_emb.pt'", f"'{os.path.join(CKPTS_DIR, 'pos_emb.pt')}'"),
        ("'neg_emb.pt'", f"'{os.path.join(CKPTS_DIR, 'neg_emb.pt')}'"),
    )
    with open(original_script_path, "r", encoding="utf-8") as f:
        script_content = f.read()
    not_found = []
    for pattern, replacement in replacements:
        if pattern not in script_content:
            not_found.append(pattern)
        script_content = script_content.replace(pattern, replacement)
    with open(patched_script_path, "w", encoding="utf-8") as f:
        f.write(script_content)
    return not_found


def _find_result_file(output_dir):
    """Return the first result file in *output_dir* (alphabetical order), or None."""
    candidates = sorted(
        name for name in os.listdir(output_dir)
        if name.lower().endswith(_RESULT_EXTENSIONS)
    )
    return os.path.join(output_dir, candidates[0]) if candidates else None


def _stop_process(process, timeout=30):
    """Stop *process* if it is still running and close its stdout pipe.

    Sends SIGTERM first and escalates to SIGKILL if the process has not exited
    within *timeout* seconds.
    """
    if process.poll() is None:
        # SIGTERM first to give torchrun a chance to shut down its own workers
        # (NOTE(review): confirm for the installed torch version); SIGKILL is
        # only the fallback.
        process.terminate()
        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
    if process.stdout is not None:
        process.stdout.close()


def run_inference(video_path, seed, res_h, res_w):
    """Run SeedVR2-3B inference on an uploaded file and stream progress to the UI.

    This is a generator used as a Gradio event handler. Every yield is a tuple
    ``(image_path, video_path, log_text)``. Intermediate yields only update the
    log; the final yield fills the image slot or the video slot depending on the
    MIME type guessed from the result file name.

    Args:
        video_path: Path of the uploaded file, or None when nothing was uploaded.
        seed: Seed forwarded to the inference script as ``--seed``.
        res_h: Output height forwarded as ``--res_h``.
        res_w: Output width forwarded as ``--res_w``.

    Raises:
        gr.Error: If no file was uploaded, a numeric field is empty, the
            inference process exits with a non-zero status, or it produced no
            result file.
    """
    if video_path is None:
        raise gr.Error("Por favor, faça o upload de um arquivo.")
    if seed is None or res_h is None or res_w is None:
        raise gr.Error("Preencha os campos Seed, Altura e Largura.")
    # gr.Number can deliver floats (e.g. 720.0); the command line needs integers.
    seed, res_h, res_w = int(seed), int(res_h), int(res_w)

    job_id = str(uuid.uuid4())
    input_dir = os.path.join("/tmp", "temp_inputs", job_id)
    # NOTE(review): output_dir is intentionally not removed here because the
    # yielded result path must stay readable by Gradio; consider a periodic cleanup.
    output_dir = os.path.join("/tmp", "temp_outputs", job_id)
    patched_script_path = os.path.join("/tmp", f"inference_patched_{job_id}.py")
    log_output = ""
    process = None
    try:
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)
        shutil.copy(video_path, input_dir)

        for pattern in _write_patched_script(patched_script_path):
            log_output += f"Aviso: padrão não encontrado no script original: {pattern}\n"

        # Paths are passed relative to SEEDVR_DIR, which is the subprocess cwd.
        # NOTE(review): the patched copy lives in /tmp, so Python puts /tmp (not
        # the repository) at sys.path[0] for the workers; imports of repo modules
        # presumably rely on PYTHONPATH or the script itself — verify.
        command = [
            "torchrun", "--nproc-per-node=4",
            os.path.relpath(patched_script_path, SEEDVR_DIR),
            "--video_path", os.path.relpath(input_dir, SEEDVR_DIR),
            "--output_dir", os.path.relpath(output_dir, SEEDVR_DIR),
            "--seed", str(seed),
            "--res_h", str(res_h),
            "--res_w", str(res_w),
        ]
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"  # make the child flush output line by line
        log_output += f"Executando comando: {' '.join(command)}\n\n"
        yield None, None, log_output

        process = subprocess.Popen(
            command,
            cwd=SEEDVR_DIR,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",  # never abort the stream on a stray non-UTF-8 byte
            env=env,
        )
        for line in process.stdout:
            log_output += line
            yield None, None, log_output
        returncode = process.wait()
        if returncode != 0:
            raise gr.Error(f"A inferência falhou (código de saída {returncode}).")

        result_path = _find_result_file(output_dir)
        if result_path is None:
            raise gr.Error("Nenhum arquivo de saída foi encontrado.")
        media_type, _ = mimetypes.guess_type(result_path)
        if media_type and media_type.startswith("image"):
            yield result_path, None, log_output
        else:
            yield None, result_path, log_output
    finally:
        # Also reached when Gradio closes the generator (cancel / disconnect):
        # make sure the GPU job does not keep running in the background.
        if process is not None:
            _stop_process(process)
        shutil.rmtree(input_dir, ignore_errors=True)
        if os.path.exists(patched_script_path):
            os.remove(patched_script_path)
# --- Gradio user interface ---
# Left column: input file and parameters; right column: image/video result and
# the live log streamed by run_inference.
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown("# 🚀 Interface de Inferência para SeedVR2")
    gr.Markdown("Faça o upload de um vídeo ou imagem, ajuste os parâmetros e clique em 'Executar'.")
    with gr.Row():
        with gr.Column(scale=1):
            # NOTE(review): gr.Video only accepts video uploads, although the
            # label (and the result handling) also mention images.
            input_media = gr.Video(label="Upload de Vídeo ou Imagem")
            # precision=0 makes Gradio hand integers to the handler, so the
            # command line receives e.g. "720" instead of "720.0".
            seed = gr.Number(value=666, label="Seed", precision=0)
            with gr.Accordion("Configurações Avançadas", open=False):
                res_h = gr.Number(value=720, label="Altura da Saída (res_h)", precision=0)
                res_w = gr.Number(value=1280, label="Largura da Saída (res_w)", precision=0)
            run_button = gr.Button("Executar", variant="primary")
        with gr.Column(scale=2):
            output_image = gr.Image(label="Saída de Imagem")
            output_video = gr.Video(label="Saída de Vídeo")
            log_box = gr.Textbox(label="Logs em Tempo Real", lines=15, autoscroll=True, interactive=False)

    # run_inference is a generator: every yield updates (image, video, log).
    run_button.click(
        fn=run_inference,
        inputs=[input_media, seed, res_h, res_w],
        outputs=[output_image, output_video, log_box],
    )
    # Example paths are relative to the process working directory.
    gr.Examples(
        examples=[
            ["./SeedVR/01.mp4", 666],
            ["./SeedVR/02.mp4", 123],
            ["./SeedVR/03.mp4", 42],
        ],
        inputs=[input_media, seed],
    )

# At most 10 requests may wait in the queue.
demo.queue(max_size=10).launch()
|