Spaces:
Paused
Paused
Delete app.py
Browse files
app.py
DELETED
|
@@ -1,191 +0,0 @@
|
|
| 1 |
-
import os
import sys
import subprocess
import importlib.util

# --- STAGE 0: final installation of flash-attn ---

# Install flash_attn only when it is not importable yet.
package_name = 'flash_attn'
if importlib.util.find_spec(package_name) is not None:
    print(f"✅ Pacote {package_name} já está instalado.")
else:
    print(f"Instalando o pacote que faltava: {package_name}. Isso pode levar um minuto...")
    # Install into the exact interpreter that is running this script.
    install_cmd = [
        sys.executable, "-m", "pip", "install",
        "flash_attn==2.5.9.post1",
        "--no-build-isolation",
    ]
    subprocess.run(install_cmd, check=True)
    print(f"✅ {package_name} instalado com sucesso.")
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# A partir daqui, o ambiente está 100% pronto.
|
| 29 |
-
# ---------------------------------------------------------------------
|
| 30 |
-
|
| 31 |
-
import spaces
|
| 32 |
-
from pathlib import Path
|
| 33 |
-
from urllib.parse import urlparse
|
| 34 |
-
import torch
|
| 35 |
-
from torch.hub import download_url_to_file
|
| 36 |
-
import mediapy
|
| 37 |
-
from einops import rearrange
|
| 38 |
-
from omegaconf import OmegaConf
|
| 39 |
-
import datetime
|
| 40 |
-
import gc
|
| 41 |
-
from PIL import Image
|
| 42 |
-
import gradio as gr
|
| 43 |
-
import uuid
|
| 44 |
-
import mimetypes
|
| 45 |
-
import torchvision.transforms as T
|
| 46 |
-
from torchvision.transforms import Compose, Lambda, Normalize
|
| 47 |
-
from torchvision.io.video import read_video
|
| 48 |
-
|
| 49 |
-
# --- STAGE 1: clone the repository and switch into it ---
repo_name = "SeedVR"
if not os.path.exists(repo_name):
    print(f"Clonando o repositório {repo_name} do GitHub...")
    # Argument-list form avoids spawning an intermediate shell (shell=True is
    # unnecessary here and a shell-injection hazard as a habit).
    subprocess.run(
        ["git", "clone", f"https://github.com/ByteDance-Seed/{repo_name}.git"],
        check=True,
    )

# Make sure we are inside the repository before importing from it.
if not os.getcwd().endswith(repo_name):
    os.chdir(repo_name)

sys.path.insert(0, os.path.abspath('.'))

# SeedVR project imports (only resolvable after the chdir above).
from data.image.transforms.divisible_crop import DivisibleCrop
from data.image.transforms.na_resize import NaResize
from data.video.transforms.rearrange import Rearrange
from common.config import load_config
from common.distributed import init_torch
from common.seed import set_seed
from projects.video_diffusion_sr.infer import VideoDiffusionInfer
from common.distributed.ops import sync_data

print("Ambiente Conda carregado e verificado. Iniciando a aplicação...")
|
| 72 |
-
|
| 73 |
-
# --- ETAPA 2: Baixar os Modelos Pré-treinados ---
|
| 74 |
-
print("Baixando modelos pré-treinados...")
|
| 75 |
-
|
| 76 |
-
def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
    """Download *url* into *model_dir* unless already cached; return the local path."""
    os.makedirs(model_dir, exist_ok=True)
    # Default the local file name to the last path component of the URL.
    target_name = file_name or os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, target_name)
    if not os.path.exists(cached_file):
        print(f'Baixando: "{url}" para {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
| 86 |
-
|
| 87 |
-
# Published SeedVR2-3B checkpoints and text-embedding tensors.
pretrain_model_url = {
    'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
    'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
    'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
    'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
}

Path('./ckpts').mkdir(exist_ok=True)
for key, url in pretrain_model_url.items():
    # VAE/DiT weights live under ./ckpts; the text embeddings stay in the repo root.
    target_dir = './ckpts' if key in ('vae', 'dit') else '.'
    load_file_from_url(url=url, model_dir=target_dir)
|
| 98 |
-
|
| 99 |
-
# --- STAGE 3: run the main application ---
# Single-process "distributed" environment expected by init_torch.
os.environ.update({
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "12355",
    "RANK": "0",
    "WORLD_SIZE": "1",
})
|
| 104 |
-
|
| 105 |
-
def configure_runner():
    """Build a VideoDiffusionInfer runner configured for the SeedVR2-3B checkpoints."""
    runner = VideoDiffusionInfer(load_config('configs_3b/main.yaml'))
    # The loaded config arrives read-only; unlock it so the runner can mutate it.
    OmegaConf.set_readonly(runner.config, False)
    init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
    runner.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model()
    # Apply the configured VAE memory cap only when this VAE implementation supports it.
    if hasattr(runner.vae, "set_memory_limit"):
        runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner
|
| 115 |
-
|
| 116 |
-
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run one SR diffusion inference pass; return frame-major (t c h w) tensors."""
    # NOTE: build the full `noises` list before `aug_noises` to keep the RNG
    # call order identical to the original implementation.
    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]
    # Broadcast identical tensors from rank 0, then move everything to the GPU.
    noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
    noises = [n.to("cuda") for n in noises]
    aug_noises = [n.to("cuda") for n in aug_noises]
    cond_latents = [l.to("cuda") for l in cond_latents]

    def _augment(latent, noise):
        # Fixed augmentation timestep t=100, transformed for this latent's shape.
        t = runner.timestep_transform(
            torch.tensor([100.0], device="cuda"),
            torch.tensor(latent.shape[1:], device="cuda")[None],
        )
        return runner.schedule.forward(latent, noise, t)

    conditions = [
        runner.get_condition(noise, task="sr", latent_blur=_augment(latent, aug))
        for noise, aug, latent in zip(noises, aug_noises, cond_latents)
    ]
    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        videos = runner.inference(noises=noises, conditions=conditions, **text_embeds_dict)
    # Channel-major -> frame-major for the per-frame post-processing downstream.
    return [rearrange(video, "c t h w -> t c h w") for video in videos]
|
| 130 |
-
|
| 131 |
-
@spaces.GPU
def generation_loop(video_path, seed=666, fps_out=24):
    """Restore an uploaded image or video with SeedVR2 and write the result.

    Args:
        video_path: Path to the uploaded media file (``None`` clears the UI).
        seed: RNG seed for reproducible sampling.
        fps_out: Frame rate of the written output video (ignored for images).

    Returns:
        ``(image_path, video_path, download_path)`` — exactly one of the first
        two is non-``None``, matching the detected media type.
    """
    if video_path is None:
        return None, None, None
    runner = configure_runner()
    text_embeds = {
        "texts_pos": [torch.load('pos_emb.pt', weights_only=True).to("cuda")],
        "texts_neg": [torch.load('neg_emb.pt', weights_only=True).to("cuda")]
    }
    runner.configure_diffusion()
    set_seed(int(seed))
    os.makedirs("output", exist_ok=True)
    res_h, res_w = 1280, 720
    transform = Compose([
        # Resize toward ~res_h*res_w total pixels, preserving aspect ratio.
        NaResize(resolution=(res_h * res_w)**0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        # Spatial dims must be divisible by 16 for the VAE.
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w")
    ])
    media_type, _ = mimetypes.guess_type(video_path)
    # bool(): guess_type returns None for unknown types; keep is_video a real bool.
    is_video = bool(media_type and media_type.startswith("video"))
    if is_video:
        video, _, _ = read_video(video_path, output_format="TCHW")
        # Cap at 121 frames and scale uint8 -> [0, 1] float.
        video = video[:121] / 255.0
        output_path = os.path.join("output", f"{uuid.uuid4()}.mp4")
    else:
        video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
        output_path = os.path.join("output", f"{uuid.uuid4()}.png")
    cond_latents = [transform(video.to("cuda"))]
    ori_length = cond_latents[0].size(2)
    cond_latents = runner.vae_encode(cond_latents)
    samples = generation_step(runner, text_embeds, cond_latents)
    sample = samples[0][:ori_length].cpu()
    # [-1, 1] float -> [0, 255] uint8, frame-major HWC layout for mediapy.
    sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
    # Release model/GPU memory before returning control to the UI — important on
    # ZeroGPU Spaces where the allocation is short-lived. (`gc` was imported but
    # never used in the original version.)
    del runner, text_embeds, cond_latents, samples
    gc.collect()
    torch.cuda.empty_cache()
    if is_video:
        mediapy.write_video(output_path, sample, fps=fps_out)
        return None, output_path, output_path
    else:
        mediapy.write_image(output_path, sample[0])
        return output_path, None, output_path
|
| 171 |
-
|
| 172 |
-
# NOTE(review): the original file's indentation was lost in the diff scrape;
# the widget nesting below is a reconstruction — verify against the live Space.
with gr.Blocks(title="SeedVR") as demo:
    gr.HTML(f"""
    <p><b>Demonstração oficial do Gradio</b> para
    <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
    <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
    🔥 <b>SeedVR2</b> é um algoritmo de restauração de imagem e vídeo em um passo para conteúdo do mundo real e AIGC.
    </p>
    """)
    # Fix: anchor target was '-blank' (a no-op window name); '_blank' opens a new tab.
    with gr.Row():
        input_file = gr.File(label="Carregar Imagem ou Vídeo")
        with gr.Column():
            seed = gr.Number(label="Seed", value=42)
            fps = gr.Number(label="FPS de Saída", value=24)
    run_button = gr.Button("Executar")
    output_image = gr.Image(label="Imagem de Saída")
    output_video = gr.Video(label="Vídeo de Saída")
    download_link = gr.File(label="Baixar Resultado")
    run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])

demo.queue().launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|