tedlasai commited on
Commit
cc63be8
·
1 Parent(s): c8862b3

updating app

Browse files
Files changed (3) hide show
  1. app.py +23 -26
  2. simplified_inference.py +72 -8
  3. simplified_validation.py +0 -108
app.py CHANGED
@@ -5,7 +5,6 @@ import argparse
5
 
6
  import gradio as gr
7
  from PIL import Image
8
- import skvideo
9
  from diffusers.utils import export_to_video
10
 
11
  from inference import load_model, inference_on_image
@@ -14,20 +13,17 @@ from inference import load_model, inference_on_image
14
  # 1. Load model
15
  # -----------------------
16
  args = argparse.Namespace()
17
- args.blur2vid_hf_repo_path = "tedlasai/blur2vid"
18
- args.pretrained_model_path = "THUDM/CogVideoX-2b"
19
- args.model_config_path = "training/configs/outsidephotos.yaml"
20
- args.video_width = 1280
21
- args.video_height = 720
22
- args.seed = None
23
 
24
  pipe, model_config = load_model(args)
25
 
26
- OUTPUT_DIR = Path("/tmp/generated_videos")
27
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
28
 
29
 
30
- def generate_video_from_image(image: Image.Image, interval_key: str, num_inference_steps: int) -> str:
31
  """
32
  Wrapper for Gradio. Takes an image and returns a video path.
33
  """
@@ -60,16 +56,15 @@ def generate_video_from_image(image: Image.Image, interval_key: str, num_inferen
60
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
61
  gr.Markdown(
62
  """
63
- # 🖼️ ➜ 🎬 Recover Motion from a Blurry Image
64
 
65
- This demo accompanies the paper **“Generating the Past, Present, and Future from a Motion-Blurred Image”**
66
- by Tedla *et al.*, ACM Transactions on Graphics (SIGGRAPH Asia 2025).
67
 
68
- - 🌐 **Project page:** <https://blur2vid.github.io/>
69
- - 💻 **Code:** <https://github.com/tedlasai/blur2vid/>
70
 
71
- Upload a blurry image and the model will generate a short video showing the recovered motion based on your selection.
72
- Note: The image will be resized to 1280×720. We recommend uploading landscape-oriented images.
73
  """
74
  )
75
 
@@ -82,35 +77,37 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
82
  )
83
 
84
  with gr.Row():
85
- tense_choice = gr.Radio(
86
- label="Select the interval to be generated:",
87
- choices=["present", "past, present and future"],
88
- value="past, present and future",
 
 
89
  interactive=True,
90
  )
91
 
92
  num_inference_steps = gr.Slider(
93
  label="Number of inference steps",
94
  minimum=4,
95
- maximum=50,
96
  step=1,
97
- value=20,
98
  info="More steps = better quality but slower",
99
  )
100
 
101
- generate_btn = gr.Button("Generate video", variant="primary")
102
 
103
  with gr.Column():
104
  video_out = gr.Video(
105
- label="Generated video",
106
  format="mp4",
107
  autoplay=True,
108
  loop=True,
109
  )
110
 
111
  generate_btn.click(
112
- fn=generate_video_from_image,
113
- inputs=[image_in, tense_choice, num_inference_steps],
114
  outputs=video_out,
115
  api_name="predict",
116
  )
 
5
 
6
  import gradio as gr
7
  from PIL import Image
 
8
  from diffusers.utils import export_to_video
9
 
10
  from inference import load_model, inference_on_image
 
13
  # 1. Load model
14
  # -----------------------
15
  args = argparse.Namespace()
16
+ args.blur2vid_hf_repo_path = "tedlasai/learn2refocus"
17
+ args.pretrained_model_path = "stabilityai/stable-video-diffusion-img2vid"
18
+ args.seed = 0
 
 
 
19
 
20
  pipe, model_config = load_model(args)
21
 
22
+ OUTPUT_DIR = Path("/tmp/output_stacks")
23
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
24
 
25
 
26
+ def generate_vstack_from_image(image: Image.Image, input_focal_position: int, num_inference_steps: int) -> str:
27
  """
28
  Wrapper for Gradio. Takes an image and returns a video path.
29
  """
 
56
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
57
  gr.Markdown(
58
  """
59
+ # 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image
60
 
61
+ This demo accompanies the paper **“Learning to Refocus with Video Diffusion Models”**
62
+ by Tedla *et al.*, SIGGRAPH Asia 2025.
63
 
64
+ - 🌐 **Project page:** <https://learn2refocus.github.io/>
65
+ - 💻 **Code:** <https://github.com/tedlasai/learn2refocus/>
66
 
67
+ Upload an image and specify the input focal position. Near - 5cm, Far - Infinity. Then, click "Generate stack" to generate a focal stack.
 
68
  """
69
  )
70
 
 
77
  )
78
 
79
  with gr.Row():
80
+ input_focal_position = gr.Slider(
81
+ label="Input focal position (Near - 5cm, Far - Infinity):",
82
+ minimum=0,
83
+ maximum=8,
84
+ step=1,
85
+ value=4,
86
  interactive=True,
87
  )
88
 
89
  num_inference_steps = gr.Slider(
90
  label="Number of inference steps",
91
  minimum=4,
92
+ maximum=25,
93
  step=1,
94
+ value=25,
95
  info="More steps = better quality but slower",
96
  )
97
 
98
+ generate_btn = gr.Button("Generate stack", variant="primary")
99
 
100
  with gr.Column():
101
  video_out = gr.Video(
102
+ label="Generated stack",
103
  format="mp4",
104
  autoplay=True,
105
  loop=True,
106
  )
107
 
108
  generate_btn.click(
109
+ fn=generate_vstack_from_image,
110
+ inputs=[image_in, input_focal_position, num_inference_steps],
111
  outputs=video_out,
112
  api_name="predict",
113
  )
simplified_inference.py CHANGED
@@ -18,20 +18,20 @@
18
 
19
  import math
20
  import os
21
- from torch.utils.data import Dataset
22
- import accelerate
23
  import numpy as np
24
  import torch
25
- import torch.nn.functional as F
26
  import torch.utils.checkpoint
27
  from accelerate.logging import get_logger
28
  from accelerate.utils import set_seed
29
- from packaging import version
30
  from tqdm.auto import tqdm
31
  from transformers import CLIPVisionModelWithProjection
32
- from simplified_validation import valid_net
33
  from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
34
  from diffusers.utils import check_min_version
 
 
 
 
 
35
  import argparse
36
  # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
37
  check_min_version("0.24.0.dev0")
@@ -40,8 +40,6 @@ logger = get_logger(__name__, log_level="INFO")
40
  import numpy as np
41
  import torch
42
  import os
43
- import glob
44
-
45
 
46
 
47
  def parse_args():
@@ -150,6 +148,68 @@ def convert_to_batch(image, input_focal_position, sample_frames=9):
150
  name = os.path.splitext(os.path.basename(scene))[0]
151
  return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def main():
154
  args = parse_args()
155
 
@@ -182,7 +242,11 @@ def main():
182
 
183
  unet.eval(); image_encoder.eval(); vae.eval()
184
  with torch.no_grad():
185
- valid_net(args, batch, unet, image_encoder, vae, 0, weight_dtype, device, num_inference_steps=args.num_inference_steps)
 
 
 
 
186
 
187
  if __name__ == "__main__":
188
  main()
 
18
 
19
  import math
20
  import os
 
 
21
  import numpy as np
22
  import torch
 
23
  import torch.utils.checkpoint
24
  from accelerate.logging import get_logger
25
  from accelerate.utils import set_seed
 
26
  from tqdm.auto import tqdm
27
  from transformers import CLIPVisionModelWithProjection
 
28
  from diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
29
  from diffusers.utils import check_min_version
30
+ from simplified_pipeline import StableVideoDiffusionPipeline
31
+ import videoio
32
+ from PIL import Image
33
+
34
+
35
  import argparse
36
  # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
37
  check_min_version("0.24.0.dev0")
 
40
  import numpy as np
41
  import torch
42
  import os
 
 
43
 
44
 
45
  def parse_args():
 
148
  name = os.path.splitext(os.path.basename(scene))[0]
149
  return {"pixel_values": pixels, "focal_stack_num": focal_stack_num, "original_pixel_values": original_pixels, 'icc_profile': icc_profile, "name": name}
150
 
151
+
152
def inference_on_image(args, batch, unet, image_encoder, vae, global_step, weight_dtype, device):
    """Run the focal-stack video-diffusion pipeline on one preprocessed batch.

    Args:
        args: Namespace with at least `pretrained_model_path` and
            `num_inference_steps`.
        batch: Dict with "pixel_values" (5-D tensor; the pipeline call below
            reads H from shape[3] and W from shape[4]) and "focal_stack_num"
            (the input focal position conditioning).
        unet, image_encoder, vae: Pre-loaded model components injected into
            the pipeline.
        global_step: Unused here; kept for signature compatibility with the
            training/validation call sites.
        weight_dtype: torch dtype for pipeline weights.
        device: Device the input tensor is moved to.

    Returns:
        Tuple of (video_frames_normalized, focal_stack_num) where the frames
        are a (num_frames, C, H, W) float tensor in [0, 1], resized to the
        input spatial size rounded down to even dimensions.
    """
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        args.pretrained_model_path,
        unet=unet,
        image_encoder=image_encoder,
        vae=vae,
        torch_dtype=weight_dtype,
    )
    pipeline.set_progress_bar_config(disable=True)

    num_frames = 9  # number of focal positions in the generated stack
    unet.eval()

    pixel_values = batch["pixel_values"].to(device)
    focal_stack_num = batch["focal_stack_num"]

    svd_output, _ = pipeline(
        pixel_values,
        height=pixel_values.shape[3],
        width=pixel_values.shape[4],
        num_frames=num_frames,
        decode_chunk_size=8,
        motion_bucket_id=0,  # no camera motion: a focal stack, not a video
        min_guidance_scale=1.5,
        max_guidance_scale=1.5,
        fps=7,
        noise_aug_strength=0,
        focal_stack_num=focal_stack_num,
        num_inference_steps=args.num_inference_steps,
    )
    video_frames = svd_output.frames[0]

    # Map model output from [-1, 1] to [0, 1] and reorder to (frames, C, H, W).
    video_frames_normalized = torch.clamp(video_frames * 0.5 + 0.5, 0, 1)
    video_frames_normalized = video_frames_normalized.permute(1, 0, 2, 3)

    # BUGFIX: the pipeline call above reads H = pixel_values.shape[3] and
    # W = pixel_values.shape[4], so the resize target must use the same axes.
    # The previous code used shape[2]/shape[3] — a leftover from valid_net(),
    # which re-bound pixel_values to a 4-D tensor before interpolating.
    target_h = (pixel_values.shape[3] // 2) * 2
    target_w = (pixel_values.shape[4] // 2) * 2
    video_frames_normalized = torch.nn.functional.interpolate(
        video_frames_normalized, (target_h, target_w), mode="bilinear"
    )

    return video_frames_normalized, focal_stack_num
192
def write_output(output_dir, frames, focal_stack_num, icc_profile):
    """Save a generated focal stack as an MP4 plus one PNG per frame.

    Args:
        output_dir: Directory to write into (created if missing).
        frames: (num_frames, C, H, W) float tensor in [0, 1].
        focal_stack_num: Input focal position; unused here but kept so the
            signature matches the inference_on_image() return tuple.
        icc_profile: ICC profile bytes to embed in each PNG, or the string
            "none" to skip embedding.
    """
    print("Validation images will be saved to ", output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # (frames, C, H, W) -> (frames, H, W, C) for the video writer.
    videoio.videosave(
        os.path.join(output_dir, "stack.mp4"),
        frames.permute(0, 2, 3, 1).cpu().numpy(),
        fps=5,
    )

    # Save every frame (not a hard-coded 9) so stacks of any length work.
    for i in range(frames.shape[0]):
        img = Image.fromarray(
            (frames[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        )
        path = os.path.join(output_dir, f"frame_{i}.png")
        if icc_profile != "none":
            # BUGFIX: assigning img.info['icc_profile'] does not embed the
            # profile in the written PNG; Pillow only embeds it when passed
            # as a keyword argument to save().
            img.save(path, icc_profile=icc_profile)
        else:
            img.save(path)
211
+
212
+
213
  def main():
214
  args = parse_args()
215
 
 
242
 
243
  unet.eval(); image_encoder.eval(); vae.eval()
244
  with torch.no_grad():
245
+ output_frames, focal_stack_num = inference_on_image(args, batch, unet, image_encoder, vae, 0, weight_dtype, device)
246
+ val_save_dir = os.path.join(args.output_dir, "validation_images", batch['name'])
247
+ write_output(val_save_dir, output_frames, focal_stack_num, batch['icc_profile'])
248
+
249
+
250
 
251
  if __name__ == "__main__":
252
  main()
simplified_validation.py DELETED
@@ -1,108 +0,0 @@
1
- from simplified_pipeline import StableVideoDiffusionPipeline
2
- import os
3
- import torch
4
- import numpy as np
5
- import videoio
6
- import matplotlib.image
7
- from PIL import Image
8
-
9
-
10
-
11
def valid_net(args, batch, unet, image_encoder, vae, global_step, weight_dtype, device):
    """Run validation inference on one batch and write the video + per-frame
    PNGs under <args.output_dir>/validation_images/position_<focal_stack_num>/.

    NOTE(review): this function was deleted in this commit (replaced by
    inference_on_image() + write_output() in simplified_inference.py). The
    diff rendering lost the original indentation; block extents of the two
    `with torch.no_grad():` sections below are reconstructed — confirm
    against the pre-commit file.
    """

    # The models need unwrapping because for compatibility in distributed training mode.

    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        args.pretrained_model_path,
        unet=unet,
        image_encoder=image_encoder,
        vae=vae,
        torch_dtype=weight_dtype,
    )

    pipeline.set_progress_bar_config(disable=True)

    # run inference
    val_save_dir = os.path.join(
        args.output_dir, "validation_images")

    print("Validation images will be saved to ", val_save_dir)

    os.makedirs(val_save_dir, exist_ok=True)

    num_frames = 9  # frames in the generated focal stack
    unet.eval()

    # clear gradients (the torch no grad is the magic that makes this work)
    with torch.no_grad():
        torch.cuda.empty_cache()

        pixel_values = batch["pixel_values"].to(device)
        original_pixel_values = batch['original_pixel_values'].to(device)
        focal_stack_num = batch["focal_stack_num"]

        svd_output, gt_frames = pipeline(
            pixel_values,
            height=pixel_values.shape[3],
            width=pixel_values.shape[4],
            num_frames=num_frames,
            decode_chunk_size=8,
            motion_bucket_id=0,
            min_guidance_scale=1.5,
            max_guidance_scale=1.5,
            fps=7,
            noise_aug_strength=0,
            focal_stack_num = focal_stack_num,
            num_inference_steps=args.num_inference_steps,
        )
        video_frames = svd_output.frames[0]
        gt_frames = gt_frames[0]

    with torch.no_grad():
        # Recover a 4-D (frames, C, H, W) reference from the original pixels:
        # 5-D input means a full stack; otherwise repeat the single image.
        if len(original_pixel_values.shape) == 5:
            pixel_values = original_pixel_values[0]  # assuming batch size is 1
        else:
            pixel_values = original_pixel_values.repeat(num_frames, 1, 1, 1)
        # Map from [-1, 1] model space to [0, 1].
        pixel_values_normalized = pixel_values*0.5 + 0.5
        pixel_values_normalized = torch.clamp(pixel_values_normalized,0,1)

        video_frames_normalized = video_frames*0.5 + 0.5
        video_frames_normalized = torch.clamp(video_frames_normalized,0,1)
        video_frames_normalized = video_frames_normalized.permute(1,0,2,3)

        gt_frames = torch.clamp(gt_frames,0,1)
        gt_frames = gt_frames.permute(1,0,2,3)

        # RESIZE images — target is the reference H/W rounded down to even,
        # so the video codec gets even dimensions.
        video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
        gt_frames = torch.nn.functional.interpolate(gt_frames, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
        pixel_values_normalized = torch.nn.functional.interpolate(pixel_values_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')

        os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/videos"), exist_ok=True)
        videoio.videosave(os.path.join(
            val_save_dir,
            f"position_{focal_stack_num}/videos/{batch['name']}.mp4",
        ), video_frames_normalized.permute(0,2,3,1).cpu().numpy(), fps=5)

        # save images
        os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/images"), exist_ok=True)
        for i in range(num_frames):
            # use Pillow to save images
            img = Image.fromarray((video_frames_normalized[i].permute(1,2,0).cpu().numpy()*255).astype(np.uint8))
            # use index to assign icc profile to img
            # NOTE(review): setting img.info does not embed the profile on
            # save; Pillow needs icc_profile= passed to save() — confirm.
            if batch['icc_profile'] != "none":
                img.info['icc_profile'] = batch['icc_profile']
            path = os.path.join(val_save_dir, f"position_{focal_stack_num}/images/{batch['name']}_frame_{i}.png")
            print("Saving image to ", path)
            img.save(os.path.join(val_save_dir, f"position_{focal_stack_num}/images/{batch['name']}_frame_{i}.png"))
        # Free the large frame tensor before returning.
        del video_frames