aducsdr committed
Commit 19ee49f · verified · 1 Parent(s): 766a8a1

Delete app.py

Files changed (1)
  1. app.py +0 -191
app.py DELETED
@@ -1,191 +0,0 @@
- import os
- import sys
- import subprocess
- import importlib.util
-
- # --- STEP 0: Final installation of flash-attn ---
-
- # Check whether flash_attn is already installed; if not, install it.
- package_name = 'flash_attn'
- spec = importlib.util.find_spec(package_name)
- if spec is None:
-     print(f"Installing missing package: {package_name}. This may take a minute...")
-     # Use the current environment's Python executable to install the package.
-     # --no-build-isolation builds flash-attn against the torch already installed here.
-     python_executable = sys.executable
-     subprocess.run(
-         [
-             python_executable, "-m", "pip", "install",
-             "flash_attn==2.5.9.post1",
-             "--no-build-isolation"
-         ],
-         check=True
-     )
-     print(f"✅ {package_name} installed successfully.")
- else:
-     print(f"✅ Package {package_name} is already installed.")
-
-
- # From this point on, the environment is fully ready.
- # ---------------------------------------------------------------------
-
- import spaces
- from pathlib import Path
- from urllib.parse import urlparse
- import torch
- from torch.hub import download_url_to_file
- import mediapy
- from einops import rearrange
- from omegaconf import OmegaConf
- import datetime
- import gc
- from PIL import Image
- import gradio as gr
- import uuid
- import mimetypes
- import torchvision.transforms as T
- from torchvision.transforms import Compose, Lambda, Normalize
- from torchvision.io.video import read_video
-
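- # The SeedVR source is pulled from GitHub at startup rather than vendored with the Space.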
- # --- STEP 1: Clone the repository and change into its directory ---
- repo_name = "SeedVR"
- if not os.path.exists(repo_name):
-     print(f"Cloning the {repo_name} repository from GitHub...")
-     subprocess.run(f"git clone https://github.com/ByteDance-Seed/{repo_name}.git", shell=True, check=True)
-
- # Make sure we are in the right directory
- if not os.getcwd().endswith(repo_name):
-     os.chdir(repo_name)
-
- sys.path.insert(0, os.path.abspath('.'))
-
- # SeedVR project imports (these only resolve after the chdir above)
- from data.image.transforms.divisible_crop import DivisibleCrop
- from data.image.transforms.na_resize import NaResize
- from data.video.transforms.rearrange import Rearrange
- from common.config import load_config
- from common.distributed import init_torch
- from common.seed import set_seed
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
- from common.distributed.ops import sync_data
-
- print("Conda environment loaded and verified. Starting the application...")
-
- # --- STEP 2: Download the pretrained models ---
- print("Downloading pretrained models...")
-
- def load_file_from_url(url, model_dir='.', progress=True, file_name=None):
-     os.makedirs(model_dir, exist_ok=True)
-     if not file_name:
-         parts = urlparse(url)
-         file_name = os.path.basename(parts.path)
-     cached_file = os.path.join(model_dir, file_name)
-     if not os.path.exists(cached_file):
-         print(f'Downloading: "{url}" to {cached_file}\n')
-         download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
-     return cached_file
-
- pretrain_model_url = {
-     'vae': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/ema_vae.pth',
-     'dit': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/seedvr2_ema_3b.pth',
-     'pos_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/pos_emb.pt',
-     'neg_emb': 'https://huggingface.co/ByteDance-Seed/SeedVR2-3B/resolve/main/neg_emb.pt',
- }
-
- Path('./ckpts').mkdir(exist_ok=True)
- for key, url in pretrain_model_url.items():
-     # Model weights go to ./ckpts; the text-embedding tensors stay in the working directory
-     model_dir = './ckpts' if key in ['vae', 'dit'] else '.'
-     load_file_from_url(url=url, model_dir=model_dir)
-
- # --- STEP 3: Run the main application ---
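- # Presumably these standard torch.distributed rendezvous variables let init_torch
- # bring up a single-process process group (rank 0, world size 1) on one GPU.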
- os.environ["MASTER_ADDR"] = "127.0.0.1"
- os.environ["MASTER_PORT"] = "12355"
- os.environ["RANK"] = str(0)
- os.environ["WORLD_SIZE"] = str(1)
-
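- # configure_runner builds the inference engine: it loads the 3B config, loads the
- # DiT checkpoint onto the GPU, sets up the VAE, and applies the configured VAE
- # memory limit when the VAE supports it.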
- def configure_runner():
-     config = load_config('configs_3b/main.yaml')
-     runner = VideoDiffusionInfer(config)
-     OmegaConf.set_readonly(runner.config, False)
-     init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
-     runner.configure_dit_model(device="cuda", checkpoint='ckpts/seedvr2_ema_3b.pth')
-     runner.configure_vae_model()
-     if hasattr(runner.vae, "set_memory_limit"):
-         runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
-     return runner
-
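- # generation_step runs the restoration pass: fresh noise is drawn for each input
- # latent, the low-quality latent is noise-augmented (_add_noise) and passed as the
- # "sr" task condition, and the restored clips are rearranged to t c h w order.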
- def generation_step(runner, text_embeds_dict, cond_latents):
-     def _move_to_cuda(x): return [i.to("cuda") for i in x]
-     noises, aug_noises = [torch.randn_like(l) for l in cond_latents], [torch.randn_like(l) for l in cond_latents]
-     noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
-     noises, aug_noises, cond_latents = map(_move_to_cuda, (noises, aug_noises, cond_latents))
-     def _add_noise(x, aug_noise):
-         # Inject noise into the conditioning latent at a fixed timestep (t = 100)
-         t = torch.tensor([100.0], device="cuda")
-         shape = torch.tensor(x.shape[1:], device="cuda")[None]
-         t = runner.timestep_transform(t, shape)
-         return runner.schedule.forward(x, aug_noise, t)
-     conditions = [runner.get_condition(n, task="sr", latent_blur=_add_noise(l, an)) for n, an, l in zip(noises, aug_noises, cond_latents)]
-     with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
-         video_tensors = runner.inference(noises=noises, conditions=conditions, **text_embeds_dict)
-     return [rearrange(v, "c t h w -> t c h w") for v in video_tensors]
-
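- # generation_loop is the Gradio handler: it accepts an image or a video, loads the
- # precomputed positive/negative text embeddings (pos_emb.pt / neg_emb.pt, so no
- # text encoder runs here), preprocesses the input, and writes results to ./output.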
- @spaces.GPU
- def generation_loop(video_path, seed=666, fps_out=24):
-     if video_path is None:
-         return None, None, None
-     runner = configure_runner()
-     text_embeds = {
-         "texts_pos": [torch.load('pos_emb.pt', weights_only=True).to("cuda")],
-         "texts_neg": [torch.load('neg_emb.pt', weights_only=True).to("cuda")]
-     }
-     runner.configure_diffusion()
-     set_seed(int(seed))
-     os.makedirs("output", exist_ok=True)
-     # Resize so the pixel area is roughly res_h * res_w (~720p), crop to multiples
-     # of 16, and map pixel values from [0, 1] to [-1, 1]
-     res_h, res_w = 1280, 720
-     transform = Compose([
-         NaResize(resolution=(res_h * res_w)**0.5, mode="area", downsample_only=False),
-         Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
-         DivisibleCrop((16, 16)),
-         Normalize(0.5, 0.5),
-         Rearrange("t c h w -> c t h w")
-     ])
-     media_type, _ = mimetypes.guess_type(video_path)
-     is_video = media_type and media_type.startswith("video")
-     if is_video:
-         video, _, _ = read_video(video_path, output_format="TCHW")
-         video = video[:121] / 255.0  # cap at 121 frames, scale to [0, 1]
-         output_path = os.path.join("output", f"{uuid.uuid4()}.mp4")
-     else:
-         video = T.ToTensor()(Image.open(video_path).convert("RGB")).unsqueeze(0)
-         output_path = os.path.join("output", f"{uuid.uuid4()}.png")
-     cond_latents = [transform(video.to("cuda"))]
-     ori_length = cond_latents[0].size(1)  # temporal length of the c t h w tensor
-     cond_latents = runner.vae_encode(cond_latents)
-     samples = generation_step(runner, text_embeds, cond_latents)
-     sample = samples[0][:ori_length].cpu()
-     # Map from [-1, 1] back to uint8 [0, 255] in t h w c order
-     sample = rearrange(sample, "t c h w -> t h w c").clip(-1, 1).add(1).mul(127.5).byte().numpy()
-     if is_video:
-         mediapy.write_video(output_path, sample, fps=fps_out)
-         return None, output_path, output_path
-     else:
-         mediapy.write_image(output_path, sample[0])
-         return output_path, None, output_path
-
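- # Gradio UI: a file input plus seed/FPS controls; generation_loop fills either the
- # image or the video output depending on the upload, and always the download link.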
- with gr.Blocks(title="SeedVR") as demo:
-     gr.HTML("""
-     <p><b>Official Gradio demo</b> for
-     <a href='https://github.com/ByteDance-Seed/SeedVR' target='_blank'>
-     <b>SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training</b></a>.<br>
-     🔥 <b>SeedVR2</b> is a one-step image and video restoration algorithm for real-world and AIGC content.
-     </p>
-     """)
-     with gr.Row():
-         input_file = gr.File(label="Upload Image or Video")
-         with gr.Column():
-             seed = gr.Number(label="Seed", value=42)
-             fps = gr.Number(label="Output FPS", value=24)
-     run_button = gr.Button("Run")
-     output_image = gr.Image(label="Output Image")
-     output_video = gr.Video(label="Output Video")
-     download_link = gr.File(label="Download Result")
-     run_button.click(fn=generation_loop, inputs=[input_file, seed, fps], outputs=[output_image, output_video, download_link])
-
- demo.queue().launch(share=True)