Spaces:
Running
on
Zero
Running
on
Zero
updating app
Browse files- app.py +23 -26
- simplified_inference.py +72 -8
- simplified_validation.py +0 -108
app.py
CHANGED
|
@@ -5,7 +5,6 @@ import argparse
|
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
from PIL import Image
|
| 8 |
-
import skvideo
|
| 9 |
from diffusers.utils import export_to_video
|
| 10 |
|
| 11 |
from inference import load_model, inference_on_image
|
|
@@ -14,20 +13,17 @@ from inference import load_model, inference_on_image
|
|
| 14 |
# 1. Load model
|
| 15 |
# -----------------------
|
| 16 |
args = argparse.Namespace()
|
| 17 |
-
args.blur2vid_hf_repo_path = "tedlasai/
|
| 18 |
-
args.pretrained_model_path = "
|
| 19 |
-
args.
|
| 20 |
-
args.video_width = 1280
|
| 21 |
-
args.video_height = 720
|
| 22 |
-
args.seed = None
|
| 23 |
|
| 24 |
pipe, model_config = load_model(args)
|
| 25 |
|
| 26 |
-
OUTPUT_DIR = Path("/tmp/
|
| 27 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 28 |
|
| 29 |
|
| 30 |
-
def
|
| 31 |
"""
|
| 32 |
Wrapper for Gradio. Takes an image and returns a video path.
|
| 33 |
"""
|
|
@@ -60,16 +56,15 @@ def generate_video_from_image(image: Image.Image, interval_key: str, num_inferen
|
|
| 60 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
| 61 |
gr.Markdown(
|
| 62 |
"""
|
| 63 |
-
# 🖼️ ➜ 🎬
|
| 64 |
|
| 65 |
-
This demo accompanies the paper **“
|
| 66 |
-
by Tedla *et al.*,
|
| 67 |
|
| 68 |
-
- 🌐 **Project page:** <https://
|
| 69 |
-
- 💻 **Code:** <https://github.com/tedlasai/
|
| 70 |
|
| 71 |
-
Upload
|
| 72 |
-
Note: The image will be resized to 1280×720. We recommend uploading landscape-oriented images.
|
| 73 |
"""
|
| 74 |
)
|
| 75 |
|
|
@@ -82,35 +77,37 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
|
| 82 |
)
|
| 83 |
|
| 84 |
with gr.Row():
|
| 85 |
-
|
| 86 |
-
label="
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
interactive=True,
|
| 90 |
)
|
| 91 |
|
| 92 |
num_inference_steps = gr.Slider(
|
| 93 |
label="Number of inference steps",
|
| 94 |
minimum=4,
|
| 95 |
-
maximum=
|
| 96 |
step=1,
|
| 97 |
-
value=
|
| 98 |
info="More steps = better quality but slower",
|
| 99 |
)
|
| 100 |
|
| 101 |
-
generate_btn = gr.Button("Generate
|
| 102 |
|
| 103 |
with gr.Column():
|
| 104 |
video_out = gr.Video(
|
| 105 |
-
label="Generated
|
| 106 |
format="mp4",
|
| 107 |
autoplay=True,
|
| 108 |
loop=True,
|
| 109 |
)
|
| 110 |
|
| 111 |
generate_btn.click(
|
| 112 |
-
fn=
|
| 113 |
-
inputs=[image_in,
|
| 114 |
outputs=video_out,
|
| 115 |
api_name="predict",
|
| 116 |
)
|
|
|
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
from PIL import Image
|
|
|
|
| 8 |
from diffusers.utils import export_to_video
|
| 9 |
|
| 10 |
from inference import load_model, inference_on_image
|
|
|
|
| 13 |
# 1. Load model
|
| 14 |
# -----------------------
|
| 15 |
args = argparse.Namespace()
|
| 16 |
+
args.blur2vid_hf_repo_path = "tedlasai/learn2refocus"
|
| 17 |
+
args.pretrained_model_path = "stabilityai/stable-video-diffusion-img2vid"
|
| 18 |
+
args.seed = 0
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
pipe, model_config = load_model(args)
|
| 21 |
|
| 22 |
+
OUTPUT_DIR = Path("/tmp/output_stacks")
|
| 23 |
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
| 24 |
|
| 25 |
|
| 26 |
+
def generate_vstack_from_image(image: Image.Image, input_focal_position: int, num_inference_steps: int) -> str:
|
| 27 |
"""
|
| 28 |
Wrapper for Gradio. Takes an image and returns a video path.
|
| 29 |
"""
|
|
|
|
| 56 |
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
| 57 |
gr.Markdown(
|
| 58 |
"""
|
| 59 |
+
# 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image
|
| 60 |
|
| 61 |
+
This demo accompanies the paper **“Learning to Refocus with Video Diffusion MOdels”**
|
| 62 |
+
by Tedla *et al.*, SIGGRAPH Asia 2025.
|
| 63 |
|
| 64 |
+
- 🌐 **Project page:** <https://learn2refocus.github.io/>
|
| 65 |
+
- 💻 **Code:** <https://github.com/tedlasai/learn2refocus/>
|
| 66 |
|
| 67 |
+
Upload an image specify the input focal position. Near - 5cm, Far - Infinity. Then, click "Generate stack" to generate a focal stack.
|
|
|
|
| 68 |
"""
|
| 69 |
)
|
| 70 |
|
|
|
|
| 77 |
)
|
| 78 |
|
| 79 |
with gr.Row():
|
| 80 |
+
input_focal_position = gr.Slider(
|
| 81 |
+
label="Input focal position (Near - 5cm, Far - Infinity):",
|
| 82 |
+
minimum=0,
|
| 83 |
+
maximum=8,
|
| 84 |
+
step=1,
|
| 85 |
+
value=4,
|
| 86 |
interactive=True,
|
| 87 |
)
|
| 88 |
|
| 89 |
num_inference_steps = gr.Slider(
|
| 90 |
label="Number of inference steps",
|
| 91 |
minimum=4,
|
| 92 |
+
maximum=25,
|
| 93 |
step=1,
|
| 94 |
+
value=25,
|
| 95 |
info="More steps = better quality but slower",
|
| 96 |
)
|
| 97 |
|
| 98 |
+
generate_btn = gr.Button("Generate stack", variant="primary")
|
| 99 |
|
| 100 |
with gr.Column():
|
| 101 |
video_out = gr.Video(
|
| 102 |
+
label="Generated stack",
|
| 103 |
format="mp4",
|
| 104 |
autoplay=True,
|
| 105 |
loop=True,
|
| 106 |
)
|
| 107 |
|
| 108 |
generate_btn.click(
|
| 109 |
+
fn=generate_vstack_from_image,
|
| 110 |
+
inputs=[image_in, input_focal_position, num_inference_steps],
|
| 111 |
outputs=video_out,
|
| 112 |
api_name="predict",
|
| 113 |
)
|
simplified_inference.py
CHANGED
|
@@ -18,20 +18,20 @@
|
|
| 18 |
|
| 19 |
import math
|
| 20 |
import os
|
| 21 |
-
from torch.utils.data import Dataset
|
| 22 |
-
import accelerate
|
| 23 |
import numpy as np
|
| 24 |
import torch
|
| 25 |
-
import torch.nn.functional as F
|
| 26 |
import torch.utils.checkpoint
|
| 27 |
from accelerate.logging import get_logger
|
| 28 |
from accelerate.utils import set_seed
|
| 29 |
-
from packaging import version
|
| 30 |
from tqdm.auto import tqdm
|
| 31 |
from transformers import CLIPVisionModelWithProjection
|
| 32 |
-
from simplified_validation import valid_net
|
| 33 |
from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
|
| 34 |
from diffusers.utils import check_min_version
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
import argparse
|
| 36 |
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 37 |
check_min_version("0.24.0.dev0")
|
|
@@ -40,8 +40,6 @@ logger = get_logger(__name__, log_level="INFO")
|
|
| 40 |
import numpy as np
|
| 41 |
import torch
|
| 42 |
import os
|
| 43 |
-
import glob
|
| 44 |
-
|
| 45 |
|
| 46 |
|
| 47 |
def parse_args():
|
|
@@ -150,6 +148,68 @@ def convert_to_batch(image, input_focal_position, sample_frames=9):
|
|
| 150 |
name = os.path.splitext(os.path.basename(scene))[0]
|
| 151 |
return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
def main():
|
| 154 |
args = parse_args()
|
| 155 |
|
|
@@ -182,7 +242,11 @@ def main():
|
|
| 182 |
|
| 183 |
unet.eval(); image_encoder.eval(); vae.eval()
|
| 184 |
with torch.no_grad():
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
if __name__ == "__main__":
|
| 188 |
main()
|
|
|
|
| 18 |
|
| 19 |
import math
|
| 20 |
import os
|
|
|
|
|
|
|
| 21 |
import numpy as np
|
| 22 |
import torch
|
|
|
|
| 23 |
import torch.utils.checkpoint
|
| 24 |
from accelerate.logging import get_logger
|
| 25 |
from accelerate.utils import set_seed
|
|
|
|
| 26 |
from tqdm.auto import tqdm
|
| 27 |
from transformers import CLIPVisionModelWithProjection
|
|
|
|
| 28 |
from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
|
| 29 |
from diffusers.utils import check_min_version
|
| 30 |
+
from simplified_pipeline import StableVideoDiffusionPipeline
|
| 31 |
+
import videoio
|
| 32 |
+
from PIL import Image
|
| 33 |
+
|
| 34 |
+
|
| 35 |
import argparse
|
| 36 |
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 37 |
check_min_version("0.24.0.dev0")
|
|
|
|
| 40 |
import numpy as np
|
| 41 |
import torch
|
| 42 |
import os
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def parse_args():
|
|
|
|
| 148 |
name = os.path.splitext(os.path.basename(scene))[0]
|
| 149 |
return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}
|
| 150 |
|
| 151 |
+
|
| 152 |
+
def inference_on_image(args, batch, unet, image_encoder, vae, global_step, weight_dtype, device):
|
| 153 |
+
|
| 154 |
+
pipeline = StableVideoDiffusionPipeline.from_pretrained(
|
| 155 |
+
args.pretrained_model_path,
|
| 156 |
+
unet=unet,
|
| 157 |
+
image_encoder=image_encoder,
|
| 158 |
+
vae=vae,
|
| 159 |
+
torch_dtype=weight_dtype,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 163 |
+
num_frames = 9
|
| 164 |
+
unet.eval()
|
| 165 |
+
|
| 166 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 167 |
+
focal_stack_num = batch["focal_stack_num"]
|
| 168 |
+
|
| 169 |
+
svd_output, _ = pipeline(
|
| 170 |
+
pixel_values,
|
| 171 |
+
height=pixel_values.shape[3],
|
| 172 |
+
width=pixel_values.shape[4],
|
| 173 |
+
num_frames=num_frames,
|
| 174 |
+
decode_chunk_size=8,
|
| 175 |
+
motion_bucket_id=0,
|
| 176 |
+
min_guidance_scale=1.5,
|
| 177 |
+
max_guidance_scale=1.5,
|
| 178 |
+
fps=7,
|
| 179 |
+
noise_aug_strength=0,
|
| 180 |
+
focal_stack_num = focal_stack_num,
|
| 181 |
+
num_inference_steps=args.num_inference_steps,
|
| 182 |
+
)
|
| 183 |
+
video_frames = svd_output.frames[0]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
video_frames_normalized = video_frames*0.5 + 0.5
|
| 187 |
+
video_frames_normalized = torch.clamp(video_frames_normalized,0,1)
|
| 188 |
+
video_frames_normalized = video_frames_normalized.permute(1,0,2,3)
|
| 189 |
+
video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
|
| 190 |
+
|
| 191 |
+
return video_frames_normalized, focal_stack_num
|
| 192 |
+
# run inference
|
| 193 |
+
def write_output(output_dir, frames, focal_stack_num, icc_profile):
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
print("Validation images will be saved to ", output_dir)
|
| 197 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 198 |
+
|
| 199 |
+
videoio.videosave(os.path.join(
|
| 200 |
+
output_dir,
|
| 201 |
+
f"stack.mp4",
|
| 202 |
+
), frames.permute(0,2,3,1).cpu().numpy(), fps=5)
|
| 203 |
+
|
| 204 |
+
#save images
|
| 205 |
+
for i in range(9):
|
| 206 |
+
#use Pillow to save images
|
| 207 |
+
img = Image.fromarray((frames[i].permute(1,2,0).cpu().numpy()*255).astype(np.uint8))
|
| 208 |
+
if icc_profile != "none":
|
| 209 |
+
img.info['icc_profile'] = icc_profile
|
| 210 |
+
img.save(os.path.join(output_dir, f"frame_{i}.png"))
|
| 211 |
+
|
| 212 |
+
|
| 213 |
def main():
|
| 214 |
args = parse_args()
|
| 215 |
|
|
|
|
| 242 |
|
| 243 |
unet.eval(); image_encoder.eval(); vae.eval()
|
| 244 |
with torch.no_grad():
|
| 245 |
+
output_frames, focal_stack_num = inference_on_image(args, batch, unet, image_encoder, vae, 0, weight_dtype, device)
|
| 246 |
+
val_save_dir = os.path.join(args.output_dir, "validation_images", batch['name'])
|
| 247 |
+
write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])
|
| 248 |
+
|
| 249 |
+
|
| 250 |
|
| 251 |
if __name__ == "__main__":
|
| 252 |
main()
|
simplified_validation.py
DELETED
|
@@ -1,108 +0,0 @@
|
|
| 1 |
-
from simplified_pipeline import StableVideoDiffusionPipeline
|
| 2 |
-
import os
|
| 3 |
-
import torch
|
| 4 |
-
import numpy as np
|
| 5 |
-
import videoio
|
| 6 |
-
import matplotlib.image
|
| 7 |
-
from PIL import Image
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def valid_net(args, batch, unet, image_encoder, vae, global_step, weight_dtype, device):
|
| 12 |
-
|
| 13 |
-
# The models need unwrapping because for compatibility in distributed training mode.
|
| 14 |
-
|
| 15 |
-
pipeline = StableVideoDiffusionPipeline.from_pretrained(
|
| 16 |
-
args.pretrained_model_path,
|
| 17 |
-
unet=unet,
|
| 18 |
-
image_encoder=image_encoder,
|
| 19 |
-
vae=vae,
|
| 20 |
-
torch_dtype=weight_dtype,
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
pipeline.set_progress_bar_config(disable=True)
|
| 24 |
-
|
| 25 |
-
# run inference
|
| 26 |
-
val_save_dir = os.path.join(
|
| 27 |
-
args.output_dir, "validation_images")
|
| 28 |
-
|
| 29 |
-
print("Validation images will be saved to ", val_save_dir)
|
| 30 |
-
|
| 31 |
-
os.makedirs(val_save_dir, exist_ok=True)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
num_frames = 9
|
| 35 |
-
unet.eval()
|
| 36 |
-
|
| 37 |
-
#clear gradients (the torch no grad is the magic that makes this work)
|
| 38 |
-
with torch.no_grad():
|
| 39 |
-
torch.cuda.empty_cache()
|
| 40 |
-
|
| 41 |
-
pixel_values = batch["pixel_values"].to(device)
|
| 42 |
-
original_pixel_values = batch['original_pixel_values'].to(device)
|
| 43 |
-
focal_stack_num = batch["focal_stack_num"]
|
| 44 |
-
|
| 45 |
-
svd_output, gt_frames = pipeline(
|
| 46 |
-
pixel_values,
|
| 47 |
-
height=pixel_values.shape[3],
|
| 48 |
-
width=pixel_values.shape[4],
|
| 49 |
-
num_frames=num_frames,
|
| 50 |
-
decode_chunk_size=8,
|
| 51 |
-
motion_bucket_id=0,
|
| 52 |
-
min_guidance_scale=1.5,
|
| 53 |
-
max_guidance_scale=1.5,
|
| 54 |
-
fps=7,
|
| 55 |
-
noise_aug_strength=0,
|
| 56 |
-
focal_stack_num = focal_stack_num,
|
| 57 |
-
num_inference_steps=args.num_inference_steps,
|
| 58 |
-
)
|
| 59 |
-
video_frames = svd_output.frames[0]
|
| 60 |
-
gt_frames = gt_frames[0]
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
with torch.no_grad():
|
| 64 |
-
|
| 65 |
-
if len(original_pixel_values.shape) == 5:
|
| 66 |
-
pixel_values = original_pixel_values[0] #assuming batch size is 1
|
| 67 |
-
else:
|
| 68 |
-
pixel_values = original_pixel_values.repeat(num_frames, 1, 1, 1)
|
| 69 |
-
pixel_values_normalized = pixel_values*0.5 + 0.5
|
| 70 |
-
pixel_values_normalized = torch.clamp(pixel_values_normalized,0,1)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
video_frames_normalized = video_frames*0.5 + 0.5
|
| 76 |
-
video_frames_normalized = torch.clamp(video_frames_normalized,0,1)
|
| 77 |
-
video_frames_normalized = video_frames_normalized.permute(1,0,2,3)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
gt_frames = torch.clamp(gt_frames,0,1)
|
| 81 |
-
gt_frames = gt_frames.permute(1,0,2,3)
|
| 82 |
-
|
| 83 |
-
#RESIZE images
|
| 84 |
-
video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
|
| 85 |
-
gt_frames = torch.nn.functional.interpolate(gt_frames, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
|
| 86 |
-
pixel_values_normalized = torch.nn.functional.interpolate(pixel_values_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
|
| 87 |
-
|
| 88 |
-
os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/videos"), exist_ok=True)
|
| 89 |
-
videoio.videosave(os.path.join(
|
| 90 |
-
val_save_dir,
|
| 91 |
-
f"position_{focal_stack_num}/videos/{batch['name']}.mp4",
|
| 92 |
-
), video_frames_normalized.permute(0,2,3,1).cpu().numpy(), fps=5)
|
| 93 |
-
|
| 94 |
-
#save images
|
| 95 |
-
os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/images"), exist_ok=True)
|
| 96 |
-
for i in range(num_frames):
|
| 97 |
-
#use Pillow to save images
|
| 98 |
-
img = Image.fromarray((video_frames_normalized[i].permute(1,2,0).cpu().numpy()*255).astype(np.uint8))
|
| 99 |
-
#use index to assign icc profile to img
|
| 100 |
-
if batch['icc_profile'] != "none":
|
| 101 |
-
img.info['icc_profile'] = batch['icc_profile']
|
| 102 |
-
path = os.path.join(val_save_dir, f"position_{focal_stack_num}/images/{batch['name']}_frame_{i}.png")
|
| 103 |
-
print("Saving image to ", path)
|
| 104 |
-
img.save(os.path.join(val_save_dir, f"position_{focal_stack_num}/images/{batch['name']}_frame_{i}.png"))
|
| 105 |
-
del video_frames
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|