Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -36,6 +36,7 @@ else:
|
|
| 36 |
from torchvision.transforms import Compose, Lambda, Normalize
|
| 37 |
from torchvision.io.video import read_video
|
| 38 |
import argparse
|
|
|
|
| 39 |
|
| 40 |
from common.distributed import (
|
| 41 |
get_device,
|
|
@@ -62,6 +63,7 @@ from urllib.parse import urlparse
|
|
| 62 |
from torch.hub import download_url_to_file, get_dir
|
| 63 |
import shlex
|
| 64 |
import uuid
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
|
@@ -76,7 +78,7 @@ subprocess.run(
|
|
| 76 |
)
|
| 77 |
|
| 78 |
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
| 79 |
-
"""Load file
|
| 80 |
|
| 81 |
Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
| 82 |
|
|
@@ -225,29 +227,6 @@ def generation_step(runner, text_embeds_dict, cond_latents):
|
|
| 225 |
@spaces.GPU(duration=100)
|
| 226 |
def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
|
| 227 |
runner = configure_runner(1)
|
| 228 |
-
output_dir = 'output/' + str(uuid.uuid4()) + '.mp4'
|
| 229 |
-
def _build_pos_and_neg_prompt():
|
| 230 |
-
# read positive prompt
|
| 231 |
-
positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \
|
| 232 |
-
hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \
|
| 233 |
-
skin pore detailing, hyper sharpness, perfect without deformations."
|
| 234 |
-
# read negative prompt
|
| 235 |
-
negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \
|
| 236 |
-
CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \
|
| 237 |
-
signature, jpeg artifacts, deformed, lowres, over-smooth"
|
| 238 |
-
return positive_text, negative_text
|
| 239 |
-
|
| 240 |
-
def _build_test_prompts(video_path):
|
| 241 |
-
positive_text, negative_text = _build_pos_and_neg_prompt()
|
| 242 |
-
original_videos = []
|
| 243 |
-
prompts = {}
|
| 244 |
-
video_list = os.listdir(video_path)
|
| 245 |
-
for f in video_list:
|
| 246 |
-
# if f.endswith(".mp4"):
|
| 247 |
-
original_videos.append(f)
|
| 248 |
-
prompts[f] = positive_text
|
| 249 |
-
print(f"Total prompts to be generated: {len(original_videos)}")
|
| 250 |
-
return original_videos, prompts, negative_text
|
| 251 |
|
| 252 |
def _extract_text_embeds():
|
| 253 |
# Text encoder forward.
|
|
@@ -294,7 +273,6 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
|
|
| 294 |
# set random seed
|
| 295 |
set_seed(seed, same_across_ranks=True)
|
| 296 |
os.makedirs('output/', exist_ok=True)
|
| 297 |
-
tgt_path = 'output/'
|
| 298 |
|
| 299 |
# get test prompts
|
| 300 |
original_videos = [video_path.split('/')[-1]]
|
|
@@ -331,13 +309,24 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
|
|
| 331 |
# read condition latents
|
| 332 |
cond_latents = []
|
| 333 |
for video in videos:
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
cond_latents.append(video_transform(video.to(torch.device("cuda"))))
|
| 342 |
|
| 343 |
ori_lengths = [video.size(1) for video in cond_latents]
|
|
@@ -386,14 +375,20 @@ def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size
|
|
| 386 |
sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
|
| 387 |
sample = sample.to(torch.uint8).numpy()
|
| 388 |
|
| 389 |
-
|
| 390 |
-
output_dir, sample
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
# print(f"Generated video size: {sample.shape}")
|
| 394 |
gc.collect()
|
| 395 |
torch.cuda.empty_cache()
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
| 397 |
|
| 398 |
|
| 399 |
with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training") as demo:
|
|
@@ -411,16 +406,17 @@ with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversar
|
|
| 411 |
|
| 412 |
# Interface
|
| 413 |
with gr.Row():
|
| 414 |
-
input_video = gr.
|
| 415 |
seed = gr.Number(label="Seeds", value=666)
|
| 416 |
fps = gr.Number(label="fps", value=24)
|
| 417 |
|
| 418 |
with gr.Row():
|
| 419 |
output_video = gr.Video(label="Output")
|
|
|
|
| 420 |
download_link = gr.File(label="Download the output")
|
| 421 |
|
| 422 |
run_button = gr.Button("Run")
|
| 423 |
-
run_button.click(fn=generation_loop, inputs=[input_video, seed, fps], outputs=[output_video, download_link])
|
| 424 |
|
| 425 |
# Examples
|
| 426 |
gr.Examples(
|
|
|
|
| 36 |
from torchvision.transforms import Compose, Lambda, Normalize
|
| 37 |
from torchvision.io.video import read_video
|
| 38 |
import argparse
|
| 39 |
+
from PIL import Image
|
| 40 |
|
| 41 |
from common.distributed import (
|
| 42 |
get_device,
|
|
|
|
| 63 |
from torch.hub import download_url_to_file, get_dir
|
| 64 |
import shlex
|
| 65 |
import uuid
|
| 66 |
+
import mimetypes
|
| 67 |
|
| 68 |
|
| 69 |
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
|
|
|
| 78 |
)
|
| 79 |
|
| 80 |
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
| 81 |
+
"""Load file from http url, will download models if necessary.
|
| 82 |
|
| 83 |
Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
| 84 |
|
|
|
|
| 227 |
@spaces.GPU(duration=100)
|
| 228 |
def generation_loop(video_path='./test_videos', seed=666, fps_out=12, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
|
| 229 |
runner = configure_runner(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
def _extract_text_embeds():
|
| 232 |
# Text encoder forward.
|
|
|
|
| 273 |
# set random seed
|
| 274 |
set_seed(seed, same_across_ranks=True)
|
| 275 |
os.makedirs('output/', exist_ok=True)
|
|
|
|
| 276 |
|
| 277 |
# get test prompts
|
| 278 |
original_videos = [video_path.split('/')[-1]]
|
|
|
|
| 309 |
# read condition latents
|
| 310 |
cond_latents = []
|
| 311 |
for video in videos:
|
| 312 |
+
media_type, _ = mimetypes.guess_type(video_path)
|
| 313 |
+
is_image = media_type and media_type.startswith("image")
|
| 314 |
+
is_video = media_type and media_type.startswith("video")
|
| 315 |
+
if is_video:
|
| 316 |
+
video = (
|
| 317 |
+
read_video(
|
| 318 |
+
os.path.join(video_path), output_format="TCHW"
|
| 319 |
+
)[0]
|
| 320 |
+
/ 255.0
|
| 321 |
+
)
|
| 322 |
+
print(f"Read video size: {video.size()}")
|
| 323 |
+
output_dir = 'output/' + str(uuid.uuid4()) + '.mp4'
|
| 324 |
+
else:
|
| 325 |
+
img = Image.open(input_file.name).convert("RGB")
|
| 326 |
+
img_tensor = T.ToTensor()(img).unsqueeze(0) # (1, C, H, W)
|
| 327 |
+
video = img_tensor.permute(0, 1, 2, 3) # (T=1, C, H, W)
|
| 328 |
+
print(f"Read Image size: {video.size()}")
|
| 329 |
+
output_dir = 'output/' + str(uuid.uuid4()) + '.png'
|
| 330 |
cond_latents.append(video_transform(video.to(torch.device("cuda"))))
|
| 331 |
|
| 332 |
ori_lengths = [video.size(1) for video in cond_latents]
|
|
|
|
| 375 |
sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round()
|
| 376 |
sample = sample.to(torch.uint8).numpy()
|
| 377 |
|
| 378 |
+
if is_image:
|
| 379 |
+
mediapy.write(output_dir, sample[0])
|
| 380 |
+
else:
|
| 381 |
+
mediapy.write_video(
|
| 382 |
+
output_dir, sample, fps=fps_out
|
| 383 |
+
)
|
| 384 |
|
| 385 |
# print(f"Generated video size: {sample.shape}")
|
| 386 |
gc.collect()
|
| 387 |
torch.cuda.empty_cache()
|
| 388 |
+
if is_image:
|
| 389 |
+
return output_dir, None
|
| 390 |
+
else:
|
| 391 |
+
return None, output_dir
|
| 392 |
|
| 393 |
|
| 394 |
with gr.Blocks(title="SeedVR2: One-Step Video Restoration via Diffusion Adversarial Post-Training") as demo:
|
|
|
|
| 406 |
|
| 407 |
# Interface
|
| 408 |
with gr.Row():
|
| 409 |
+
input_video = gr.File(label="Upload image or video", type="file")
|
| 410 |
seed = gr.Number(label="Seeds", value=666)
|
| 411 |
fps = gr.Number(label="fps", value=24)
|
| 412 |
|
| 413 |
with gr.Row():
|
| 414 |
output_video = gr.Video(label="Output")
|
| 415 |
+
output_image = gr.Image(label="Output_Image")
|
| 416 |
download_link = gr.File(label="Download the output")
|
| 417 |
|
| 418 |
run_button = gr.Button("Run")
|
| 419 |
+
run_button.click(fn=generation_loop, inputs=[input_video, seed, fps], outputs=[output_image, output_video, download_link])
|
| 420 |
|
| 421 |
# Examples
|
| 422 |
gr.Examples(
|