tedlasai commited on
Commit
30f9e7d
·
1 Parent(s): 18f31b4
Files changed (1) hide show
  1. app.py +47 -30
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import spaces
2
  from pathlib import Path
3
  import argparse
@@ -20,11 +21,11 @@ pipe, device = load_model(args)
20
  OUTPUT_DIR = Path("/tmp/output_stacks")
21
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
22
 
23
- NUM_FRAMES = 9 # frame_0.png ... frame_8.png
24
- LOOP_S = 0.3 # autoplay speed (s)
25
 
26
  @spaces.GPU(timeout=300, duration=80)
27
- def generate_vstack_from_image(image: Image.Image, input_focal_position: int, num_inference_steps: int):
28
  if image is None:
29
  raise gr.Error("Please upload an image first.")
30
 
@@ -37,35 +38,37 @@ def generate_vstack_from_image(image: Image.Image, input_focal_position: int, nu
37
 
38
  write_output(OUTPUT_DIR, output_frames, focal_stack_num, batch["icc_profile"])
39
 
 
40
  first_frame = OUTPUT_DIR / "frame_0.png"
 
 
 
41
  if not first_frame.exists():
42
  raise gr.Error("frame_0.png not found in output_dir")
43
 
44
- # Show first frame + reset slider
45
- return str(first_frame), gr.update(value=0)
46
 
47
  def show_frame(idx: int):
48
- path = OUTPUT_DIR / f"frame_{idx}.png"
49
  if not path.exists():
50
  return None
51
  return str(path)
52
 
53
- def advance_frame(idx: int, autoplay: bool):
54
- if not autoplay:
55
- return gr.update(value=int(idx)), gr.update()
56
 
57
- next_idx = (int(idx) + 1) % NUM_FRAMES
58
- next_path = OUTPUT_DIR / f"frame_{next_idx}.png"
59
- if not next_path.exists():
60
- return gr.update(value=int(idx)), gr.update()
 
 
61
 
62
- return gr.update(value=next_idx), str(next_path)
63
 
64
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
65
  gr.Markdown(
66
  """
67
  # 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image
68
- Generate a focal stack and scrub through frames. Toggle autoplay on the output panel.
69
  """
70
  )
71
 
@@ -79,6 +82,7 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
79
  maximum=8,
80
  step=1,
81
  value=4,
 
82
  )
83
 
84
  num_inference_steps = gr.Slider(
@@ -93,22 +97,36 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
93
  generate_btn = gr.Button("Generate stack", variant="primary")
94
 
95
  with gr.Column():
96
- frame_view = gr.Image(label="Frame viewer", type="filepath")
 
 
 
 
97
 
98
- frame_slider = gr.Slider(
99
- minimum=0,
100
- maximum=NUM_FRAMES - 1,
101
- step=1,
102
- value=0,
103
- label="Output Focus Position",
 
104
  )
105
 
106
- autoplay = gr.Checkbox(value=True, label="Autoplay")
 
 
 
 
 
 
 
 
 
107
 
108
  generate_btn.click(
109
- fn=generate_vstack_from_image,
110
  inputs=[image_in, input_focal_position, num_inference_steps],
111
- outputs=[frame_view, frame_slider],
112
  api_name="predict",
113
  )
114
 
@@ -118,11 +136,10 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
118
  outputs=frame_view,
119
  )
120
 
121
- timer = gr.Timer(LOOP_S)
122
- timer.tick(
123
- fn=advance_frame,
124
- inputs=[frame_slider, autoplay],
125
- outputs=[frame_slider, frame_view],
126
  )
127
 
128
  if __name__ == "__main__":
 
1
+ import os
2
  import spaces
3
  from pathlib import Path
4
  import argparse
 
21
  OUTPUT_DIR = Path("/tmp/output_stacks")
22
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
23
 
24
+ NUM_FRAMES = 9 # frame_0.png ... frame_8.png
25
+
26
 
27
  @spaces.GPU(timeout=300, duration=80)
28
+ def generate_outputs(image: Image.Image, input_focal_position: int, num_inference_steps: int):
29
  if image is None:
30
  raise gr.Error("Please upload an image first.")
31
 
 
38
 
39
  write_output(OUTPUT_DIR, output_frames, focal_stack_num, batch["icc_profile"])
40
 
41
+ video_path = OUTPUT_DIR / "stack.mp4"
42
  first_frame = OUTPUT_DIR / "frame_0.png"
43
+
44
+ if not video_path.exists():
45
+ raise gr.Error("stack.mp4 not found in output_dir")
46
  if not first_frame.exists():
47
  raise gr.Error("frame_0.png not found in output_dir")
48
 
49
+ return str(video_path), str(first_frame), gr.update(value=0)
50
+
51
 
52
  def show_frame(idx: int):
53
+ path = OUTPUT_DIR / f"frame_{int(idx)}.png"
54
  if not path.exists():
55
  return None
56
  return str(path)
57
 
 
 
 
58
 
59
+ def set_view_mode(mode: str):
60
+ show_video = (mode == "Video")
61
+ return (
62
+ gr.update(visible=show_video),
63
+ gr.update(visible=not show_video),
64
+ )
65
 
 
66
 
67
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
68
  gr.Markdown(
69
  """
70
  # 🖼️ ➜ 🎬 Generate Focal Stacks from a Single Image
71
+ Switch between **Video** (stack.mp4) and **Frames** (PNG slider).
72
  """
73
  )
74
 
 
82
  maximum=8,
83
  step=1,
84
  value=4,
85
+ interactive=True,
86
  )
87
 
88
  num_inference_steps = gr.Slider(
 
97
  generate_btn = gr.Button("Generate stack", variant="primary")
98
 
99
  with gr.Column():
100
+ view_mode = gr.Radio(
101
+ choices=["Video", "Frames"],
102
+ value="Video",
103
+ label="Output view",
104
+ )
105
 
106
+ # --- Video output ---
107
+ video_out = gr.Video(
108
+ label="Generated stack (stack.mp4)",
109
+ format="mp4",
110
+ autoplay=True,
111
+ loop=True,
112
+ visible=True,
113
  )
114
 
115
+ # --- Frames output (group) ---
116
+ with gr.Group(visible=False) as frames_group:
117
+ frame_view = gr.Image(label="Frame viewer", type="filepath")
118
+ frame_slider = gr.Slider(
119
+ minimum=0,
120
+ maximum=NUM_FRAMES - 1,
121
+ step=1,
122
+ value=0,
123
+ label="Output focal position",
124
+ )
125
 
126
  generate_btn.click(
127
+ fn=generate_outputs,
128
  inputs=[image_in, input_focal_position, num_inference_steps],
129
+ outputs=[video_out, frame_view, frame_slider],
130
  api_name="predict",
131
  )
132
 
 
136
  outputs=frame_view,
137
  )
138
 
139
+ view_mode.change(
140
+ fn=set_view_mode,
141
+ inputs=view_mode,
142
+ outputs=[video_out, frames_group],
 
143
  )
144
 
145
  if __name__ == "__main__":