File size: 5,530 Bytes
8ee2f06
2667dea
6bfaaa7
2667dea
d4386a4
eecde3e
2588a29
 
 
8ee2f06
2588a29
3d5bbef
 
 
2588a29
8ee2f06
 
 
 
 
07ebdd7
 
 
 
 
 
 
 
 
3d5bbef
 
07ebdd7
22e1c46
 
2588a29
07ebdd7
3d5bbef
07ebdd7
2588a29
 
67137e5
2588a29
67137e5
07ebdd7
5da6833
07ebdd7
 
 
 
3d5bbef
67137e5
07ebdd7
 
 
67137e5
07ebdd7
2588a29
 
67137e5
3d5bbef
5da6833
2588a29
5da6833
2588a29
 
07ebdd7
 
2588a29
 
07ebdd7
2588a29
07ebdd7
3d5bbef
 
 
 
 
 
 
 
 
 
 
 
 
22e1c46
bf86b52
22e1c46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# app.py (final version, with lazy torch import)

import gradio as gr
import os
import uuid
import shutil
import subprocess
import mimetypes
from pathlib import Path
# NOTE: torch.hub is intentionally NOT imported here; it is imported lazily
# below so that importing this module stays cheap until models are needed.

# --- MODEL CONFIGURATION AND DOWNLOAD BLOCK ---
APP_DIR = "/app"
SEEDVR_DIR = os.path.join(APP_DIR, "SeedVR")
MODEL_CACHE_DIR = "/tmp/models"
CKPTS_DIR = os.path.join(MODEL_CACHE_DIR, "ckpts")
os.makedirs(CKPTS_DIR, exist_ok=True)

print("Verificando e baixando modelos para /tmp/models/ckpts...")
# Lazy import: only pulled in when the download step actually runs.
from torch.hub import download_url_to_file

# Checkpoint filename -> Hugging Face download URL.
files_to_download = {
    "seedvr2_ema_3b.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth",
    "ema_vae.pth": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth",
    "pos_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt",
    "neg_emb.pt": "https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt",
}
for filename, url in files_to_download.items():
    destination_path = os.path.join(CKPTS_DIR, filename)
    if not os.path.exists(destination_path):
        # BUG FIX: original printed the literal "(unknown)" instead of the
        # actual filename being downloaded.
        print(f"Baixando {filename}...")
        download_url_to_file(url, destination_path)
    else:
        print(f"{filename} já existe.")
print("Verificação de modelos concluída.")
# --- The rest of the code stays the same ---
def run_inference(video_path, seed, res_h, res_w):
    """Run SeedVR2 inference on an uploaded video/image, streaming logs.

    Generator used as a Gradio event handler. It patches the upstream
    inference script so all checkpoint paths point at CKPTS_DIR, launches
    it under ``torchrun`` and yields ``(image_path, video_path, log_text)``
    tuples as subprocess output arrives, so the UI log box updates live.

    Args:
        video_path: Local path of the uploaded file (None if nothing uploaded).
        seed: Random seed forwarded to the inference script.
        res_h: Output height forwarded to the inference script.
        res_w: Output width forwarded to the inference script.

    Yields:
        (image_or_None, video_or_None, accumulated_log_text) — the final
        yield routes the result to the image or video output based on its
        MIME type.

    Raises:
        gr.Error: On missing upload, non-zero subprocess exit, or no output file.
    """
    if video_path is None:
        raise gr.Error("Por favor, faça o upload de um arquivo.")
    # Per-job scratch directories keyed by a fresh UUID so concurrent jobs
    # (and retries) never collide.
    job_id = str(uuid.uuid4())
    input_dir = os.path.join("/tmp", "temp_inputs", job_id)
    output_dir = os.path.join("/tmp", "temp_outputs", job_id)
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    shutil.copy(video_path, input_dir)
    log_output = ""
    patched_script_path = os.path.join("/tmp", f"inference_patched_{job_id}.py")
    try:
        # Rewrite the upstream script's hard-coded relative checkpoint paths
        # to the absolute copies downloaded into CKPTS_DIR.
        original_script_path = os.path.join(SEEDVR_DIR, "projects", "inference_seedvr2_3b.py")
        with open(original_script_path, 'r') as f:
            script_content = f.read()
        replacements = {
            "'./ckpts/seedvr2_ema_3b.pth'": f"'{os.path.join(CKPTS_DIR, 'seedvr2_ema_3b.pth')}'",
            "runner.configure_vae_model()": f"runner.configure_vae_model(checkpoint_path='{os.path.join(CKPTS_DIR, 'ema_vae.pth')}')",
            "'pos_emb.pt'": f"'{os.path.join(CKPTS_DIR, 'pos_emb.pt')}'",
            "'neg_emb.pt'": f"'{os.path.join(CKPTS_DIR, 'neg_emb.pt')}'",
        }
        for old, new in replacements.items():
            script_content = script_content.replace(old, new)
        with open(patched_script_path, 'w') as f:
            f.write(script_content)
        # The script is executed with cwd=SEEDVR_DIR, so pass paths relative
        # to that directory.
        input_folder_relative = os.path.relpath(input_dir, SEEDVR_DIR)
        output_folder_relative = os.path.relpath(output_dir, SEEDVR_DIR)
        patched_script_relative_path = os.path.relpath(patched_script_path, SEEDVR_DIR)
        command = [
            "torchrun", "--nproc-per-node=4", patched_script_relative_path,
            "--video_path", input_folder_relative,
            "--output_dir", output_folder_relative,
            "--seed", str(seed),
            "--res_h", str(res_h),
            "--res_w", str(res_w),
        ]
        env = os.environ.copy()
        env["PYTHONUNBUFFERED"] = "1"  # make the child flush line-by-line for live logs
        log_output += f"Executando comando: {' '.join(command)}\n\n"
        yield None, None, log_output
        # Merge stderr into stdout so a single stream feeds the log box.
        process = subprocess.Popen(
            command, cwd=SEEDVR_DIR,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, encoding='utf-8', env=env,
        )
        with process.stdout:
            for line in iter(process.stdout.readline, ''):
                log_output += line
                yield None, None, log_output
        # BUG FIX: wait() gives the definitive exit code (poll() can race the
        # process teardown) and the code is now included in the error message.
        returncode = process.wait()
        if returncode != 0:
            raise gr.Error(f"A inferência falhou. (código de saída: {returncode})")
        output_files = [f for f in os.listdir(output_dir)
                        if f.endswith(('.mp4', '.png', '.jpg', '.jpeg'))]
        if not output_files:
            raise gr.Error("Nenhum arquivo de saída foi encontrado.")
        result_path = os.path.join(output_dir, output_files[0])
        # Route the result to the image or the video component by MIME type.
        media_type, _ = mimetypes.guess_type(result_path)
        if media_type and media_type.startswith("image"):
            yield result_path, None, log_output
        else:
            yield None, result_path, log_output
    finally:
        # output_dir is deliberately kept: Gradio serves the result from it.
        shutil.rmtree(input_dir, ignore_errors=True)
        if os.path.exists(patched_script_path):
            os.remove(patched_script_path)
# --- Gradio UI: layout definition and app launch ---
# CSS hides Gradio's default footer.
with gr.Blocks(css="footer {display: none !important}") as demo:
    gr.Markdown("# 🚀 Interface de Inferência para SeedVR2")
    gr.Markdown("Faça o upload de um vídeo ou imagem, ajuste os parâmetros e clique em 'Executar'.")
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            input_media = gr.Video(label="Upload de Vídeo ou Imagem")
            seed = gr.Number(value=666, label="Seed")
            with gr.Accordion("Configurações Avançadas", open=False):
                res_h = gr.Number(value=720, label="Altura da Saída (res_h)")
                res_w = gr.Number(value=1280, label="Largura da Saída (res_w)")
            run_button = gr.Button("Executar", variant="primary")
        # Right column: outputs — run_inference yields into these three
        # components (image, video, live log text) in this order.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Saída de Imagem")
            output_video = gr.Video(label="Saída de Vídeo")
            log_box = gr.Textbox(label="Logs em Tempo Real", lines=15, autoscroll=True, interactive=False)
    run_button.click(fn=run_inference, inputs=[input_media, seed, res_h, res_w], outputs=[output_image, output_video, log_box])

# Bound queue so long-running inference jobs don't pile up without limit.
demo.queue(max_size=10).launch()