Spaces:
Running on Zero
Running on Zero
initial
Browse files- .gitattributes +1 -0
- app.py +365 -0
- demo_utils.py +579 -0
- environment.yaml +29 -0
- example_inputs/aerial.mp4 +3 -0
- example_inputs/aerial_thumb.png +0 -0
- example_inputs/doorbell.mp4 +3 -0
- example_inputs/doorbell_thumb.png +0 -0
- example_inputs/tomjerry.mp4 +3 -0
- example_inputs/tomjerry_thumb.png +0 -0
- requirements.txt +15 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import gc
|
| 6 |
+
from demo_utils import load_model, process_video, save_video, image_to_video
|
| 7 |
+
import av
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
# Detect HuggingFace ZeroGPU support: the `spaces` package is only installed
# on Spaces hardware, so a failed import means we run without it.
try:
    import spaces
except ImportError:
    ZEROGPU_AVAILABLE = False
    print("Warning: spaces module not available. Running without ZeroGPU support.")
else:
    ZEROGPU_AVAILABLE = True
|
| 17 |
+
|
| 18 |
+
# Per-device cache of loaded model setups, so repeated requests on the same
# device never reload weights.
model_cache = {}


def get_model(device):
    """Return the model setup for *device*, loading and caching it on first use."""
    try:
        return model_cache[device]
    except KeyError:
        setup = load_model(device=device)
        model_cache[device] = setup
        return setup
|
| 24 |
+
|
| 25 |
+
# Target CUDA when a GPU is visible locally; under ZeroGPU the GPU is attached
# lazily per-request, so we optimistically pick "cuda" whenever the `spaces`
# module imported successfully.
device = "cuda" if torch.cuda.is_available() or ZEROGPU_AVAILABLE else "cpu"
|
| 26 |
+
|
| 27 |
+
def cleanup_gpu():
    """Release Python garbage and cached CUDA memory (no-op on CPU-only hosts)."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
| 33 |
+
|
| 34 |
+
def extract_metadata(file):
    """Probe an uploaded file and describe it for the UI.

    Returns a 6-tuple ``(info_text, tmp_path, total_frames, fps, width,
    height)``; all entries after the first are None when *file* is None.
    Still images are wrapped into a one-frame mp4 so downstream code only
    ever deals with video paths.
    """
    if file is None:
        return "", None, None, None, None, None

    ext = os.path.splitext(file.name)[1].lower()
    if ext in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'):
        # Reserve a temp mp4 path (delete=False: the converted clip must
        # outlive this function) and render the image into it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video:
            tmp_path = tmp_video.name
        meta = image_to_video(file.name, tmp_path, fps=1.0)
        total_frames, fps = meta['frames'], meta['fps']
        width, height = meta['width'], meta['height']
        info_text = f"{width}×{height} | Image (1 frame)"
    else:
        # Real video: read stream properties straight from the container.
        tmp_path = file.name
        container = av.open(tmp_path)
        stream = container.streams.video[0]
        total_frames = stream.frames
        fps = float(stream.average_rate)
        width, height = stream.width, stream.height
        container.close()
        info_text = f"{width}×{height} | {total_frames} frames @ {fps:.1f} FPS"

    return info_text, tmp_path, total_frames, fps, width, height
|
| 65 |
+
|
| 66 |
+
def handle_file_upload(file):
    """Upload handler: return (info text, full metadata tuple, source fps)."""
    meta = extract_metadata(file)
    # meta[1] is the resolved media path; None means nothing usable was uploaded.
    if meta[1] is None:
        return "", None, None
    # meta[0] = info text, meta[3] = fps (pre-fills the Output FPS field).
    return meta[0], meta, meta[3]
|
| 74 |
+
|
| 75 |
+
def _process_video_impl(file_info, gazing_ratio, task_loss_requirement, output_fps, progress=None):
|
| 76 |
+
if file_info is None:
|
| 77 |
+
return None, None, None, None, None, None, None, "No file uploaded"
|
| 78 |
+
|
| 79 |
+
_, tmp_path, total_frames, fps, _, _ = file_info
|
| 80 |
+
|
| 81 |
+
if tmp_path is None:
|
| 82 |
+
return None, None, None, None, None, None, None, "Invalid file"
|
| 83 |
+
|
| 84 |
+
# Yield initial status
|
| 85 |
+
yield None, None, None, None, None, None, None, "Loading model..."
|
| 86 |
+
|
| 87 |
+
if progress:
|
| 88 |
+
progress(0.0, desc="Loading model...")
|
| 89 |
+
setup = get_model(device)
|
| 90 |
+
|
| 91 |
+
yield None, None, None, None, None, None, None, "Processing video..."
|
| 92 |
+
|
| 93 |
+
if progress:
|
| 94 |
+
progress(0.1, desc="Processing video...")
|
| 95 |
+
|
| 96 |
+
status_messages = []
|
| 97 |
+
|
| 98 |
+
def update_progress(pct, msg):
|
| 99 |
+
if progress:
|
| 100 |
+
progress(pct, desc=msg)
|
| 101 |
+
status_messages.append(msg)
|
| 102 |
+
|
| 103 |
+
# Convert UI gazing ratio to model gazing ratio
|
| 104 |
+
# UI: ranges from 1/196 to 265/196 (effective patches per frame / 196)
|
| 105 |
+
# Model: needs value * (196/265) to get actual gazing ratio
|
| 106 |
+
model_gazing_ratio = gazing_ratio * (196 / 265)
|
| 107 |
+
|
| 108 |
+
for results in process_video(
|
| 109 |
+
tmp_path,
|
| 110 |
+
setup,
|
| 111 |
+
gazing_ratio=model_gazing_ratio,
|
| 112 |
+
task_loss_requirement=task_loss_requirement,
|
| 113 |
+
progress_callback=update_progress,
|
| 114 |
+
spatial_batch_size=2 # Process 4 spatial chunks at a time to avoid OOM
|
| 115 |
+
):
|
| 116 |
+
if status_messages:
|
| 117 |
+
yield None, None, None, None, None, None, None, status_messages[-1]
|
| 118 |
+
|
| 119 |
+
yield None, None, None, None, None, None, None, "Saving output videos..."
|
| 120 |
+
|
| 121 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 122 |
+
original_path = os.path.join(tmpdir, "original.mp4")
|
| 123 |
+
gazing_path = os.path.join(tmpdir, "gazing.mp4")
|
| 124 |
+
recon_path = os.path.join(tmpdir, "reconstruction.mp4")
|
| 125 |
+
scales_stitch_path = os.path.join(tmpdir, "scales_stitch.mp4")
|
| 126 |
+
|
| 127 |
+
# Use output_fps if specified, otherwise use original video fps
|
| 128 |
+
fps_to_use = output_fps if output_fps is not None else results['fps']
|
| 129 |
+
|
| 130 |
+
save_video(results['original_frames'], original_path, fps_to_use)
|
| 131 |
+
save_video(results['gazing_frames'], gazing_path, fps_to_use)
|
| 132 |
+
save_video(results['reconstruction_frames'], recon_path, fps_to_use)
|
| 133 |
+
save_video(results['scales_stitch_frames'], scales_stitch_path, fps_to_use)
|
| 134 |
+
|
| 135 |
+
with open(original_path, "rb") as f:
|
| 136 |
+
original_data = f.read()
|
| 137 |
+
with open(gazing_path, "rb") as f:
|
| 138 |
+
gazing_data = f.read()
|
| 139 |
+
with open(recon_path, "rb") as f:
|
| 140 |
+
recon_data = f.read()
|
| 141 |
+
with open(scales_stitch_path, "rb") as f:
|
| 142 |
+
scales_stitch_data = f.read()
|
| 143 |
+
|
| 144 |
+
original_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
| 145 |
+
original_file.write(original_data)
|
| 146 |
+
original_file.close()
|
| 147 |
+
|
| 148 |
+
gazing_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
| 149 |
+
gazing_file.write(gazing_data)
|
| 150 |
+
gazing_file.close()
|
| 151 |
+
|
| 152 |
+
recon_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
| 153 |
+
recon_file.write(recon_data)
|
| 154 |
+
recon_file.close()
|
| 155 |
+
|
| 156 |
+
scales_stitch_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
| 157 |
+
scales_stitch_file.write(scales_stitch_data)
|
| 158 |
+
scales_stitch_file.close()
|
| 159 |
+
|
| 160 |
+
gazing_pct_text = f"{results['gazing_pct']:.2%}"
|
| 161 |
+
gazing_tokens_text = f"{results['total_gazing_tokens']:,}"
|
| 162 |
+
total_tokens_text = f"{results['total_possible_tokens']:,}"
|
| 163 |
+
|
| 164 |
+
yield (
|
| 165 |
+
gazing_pct_text,
|
| 166 |
+
gazing_tokens_text,
|
| 167 |
+
total_tokens_text,
|
| 168 |
+
original_file.name,
|
| 169 |
+
gazing_file.name,
|
| 170 |
+
recon_file.name,
|
| 171 |
+
scales_stitch_file.name,
|
| 172 |
+
"Processing complete!"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# On ZeroGPU Spaces the handler must be wrapped so a GPU is attached for at
# most `duration` seconds per call; locally we expose the implementation as-is.
process_video_ui = (
    spaces.GPU(duration=120)(_process_video_impl)
    if ZEROGPU_AVAILABLE
    else _process_video_impl
)
|
| 179 |
+
|
| 180 |
+
def extract_first_frame_thumbnail(video_path, output_path, size=(200, 200), force=False):
    """Save the first frame of *video_path* as a fixed-size square thumbnail.

    An existing thumbnail at *output_path* is kept unless *force* is True.
    The frame is center-cropped to a square before resizing, so the output
    has a fixed aspect ratio regardless of the video's.
    """
    if not force and os.path.exists(output_path):
        return
    container = av.open(video_path)
    for frame in container.decode(video=0):
        image = frame.to_image()
        w, h = image.size
        side = min(w, h)
        x0 = (w - side) // 2
        y0 = (h - side) // 2
        square = image.crop((x0, y0, x0 + side, y0 + side))
        square.resize(size, Image.LANCZOS).save(output_path)
        break  # only the first decodable frame is needed
    container.close()
|
| 197 |
+
|
| 198 |
+
# Example clips shown in the gallery; their thumbnails are regenerated on
# every startup so layout changes take effect immediately.
example_videos = [
    "example_inputs/aerial.mp4",
    "example_inputs/doorbell.mp4",
    "example_inputs/tomjerry.mp4",
]

for _video in example_videos:
    if os.path.exists(_video):
        # Force square 100x100 thumbnails so they match the gallery row height.
        extract_first_frame_thumbnail(_video, _video.replace('.mp4', '_thumb.png'), size=(100, 100), force=True)

# Load thumbnails as numpy arrays for gr.Gallery (same order as example_videos).
aerial_thumb_img, doorbell_thumb_img, tomjerry_thumb_img = [
    np.array(Image.open(p.replace('.mp4', '_thumb.png'))) for p in example_videos
]
|
| 215 |
+
|
| 216 |
+
# --- Gradio UI -------------------------------------------------------------
# delete_cache=(frequency, age): purge generated files older than 24h, every 24h.
with gr.Blocks(title="AutoGaze Demo", delete_cache=(86400, 86400)) as demo:
    gr.Markdown("# AutoGaze Official Demo")
    gr.Markdown("## **Attend Before Attention: Efficient and Scalable Video Understanding via Autoregressive Gazing**")
    gr.Markdown("""
    <div style="text-align: left; margin: 10px 0; font-size: 1.2em; font-weight: 600;">
    📄 <a href="https://arxiv.org/abs/PLACEHOLDER" target="_blank" style="text-decoration: none; color: inherit;">Paper</a> 🌐 <a href="https://placeholder-website.com" target="_blank" style="text-decoration: none; color: inherit;">Project Website</a>
    </div>
    """)

    # Holds the tuple returned by extract_metadata between upload and Process click.
    file_metadata = gr.State()

    with gr.Row():
        with gr.Column(scale=2):
            uploaded_file = gr.File(
                label="Upload Video or Image",
                file_types=["video", "image"]
            )
        with gr.Column(scale=1):
            file_info = gr.Textbox(label="File Info", interactive=False)
            process_button = gr.Button("Process Video", variant="primary")


    def load_example_video(evt: gr.SelectData):
        # Map gallery thumbnail index -> bundled example video path; must stay
        # in sync with the order of example_gallery's value list below.
        video_map = {
            0: "example_inputs/aerial.mp4",
            1: "example_inputs/doorbell.mp4",
            2: "example_inputs/tomjerry.mp4",
        }
        return video_map[evt.index]

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Example Videos - Click Thumbnail to Load")
            example_gallery = gr.Gallery(
                value=[
                    (aerial_thumb_img, "aerial.mp4"),
                    (doorbell_thumb_img, "doorbell.mp4"),
                    (tomjerry_thumb_img, "tomjerry.mp4"),
                ],
                label="",
                show_label=False,
                columns=3,
                rows=1,
                height=200,
                object_fit="contain",
                allow_preview=False
            )
            gr.Markdown("### Settings")

            with gr.Accordion("Output Settings", open=True):
                # value=None means "use the source video's own fps".
                fps_slider = gr.Number(
                    label="Output FPS",
                    value=None,
                    minimum=1,
                    maximum=120,
                    info="Frames per second for displaying output videos (only affects playback speed)"
                )

            with gr.Accordion("Model Parameters", open=True):
                # Bounds come from the patch grid: 1..265 effective patches / 196.
                gazing_ratio_slider = gr.Slider(
                    label="Gazing Ratio",
                    minimum=round(1/196, 2),
                    maximum=round(265/196, 2),
                    step=0.01,
                    value=0.75,
                    info="Max fraction of patches to gaze at per frame"
                )
                task_loss_slider = gr.Slider(
                    label="Task Loss Requirement",
                    minimum=0.0,
                    maximum=1.5,
                    step=0.05,
                    value=0.6,
                    info="Reconstruction loss threshold"
                )

            with gr.Accordion("FAQ", open=False):
                gr.Markdown("""
                **What file formats are supported?**

                The app supports common video formats (MP4, AVI, MOV, etc.) and image formats (JPG, PNG, etc.).

                **What is the Gazing Ratio?**

                The gazing ratio explicitly controls how many patches the model looks at per frame. Higher values mean more patches are selected. The range extends to past 1.0 because of multi-scale gazing; if all patches at all scales are selected, the ratio can reach up to 1.35.

                **What is Task Loss Requirement?**

                This threshold determines when the model stops gazing at a frame, based on the predicted reconstruction loss from the current gazed patches. Lower = more gazing, higher = less gazing.

                **How do Gazing Ratio and Task Loss interact?**

                These two parameters separately control the number of gazed patches in an image/video. This demo will take the stricter of the two requirements when determining how many patches to gaze at. For example, if the gazing ratio suggests gazing at 15% of patches, but the task loss requirement is met after only 7% patches, then only 7% patches will be gazed at. To only use one of the two parameters, set the other to its maximum value.
                """)

        with gr.Column(scale=2):
            gr.Markdown("### Results")

            # Live status line updated by the streaming process handler.
            status_text = gr.Markdown("Ready")

            with gr.Row():
                gazing_pct = gr.Textbox(label="Gazing %", interactive=False)
                gazing_tokens = gr.Textbox(label="# Gazed Patches", interactive=False)
                total_tokens = gr.Textbox(label="Total Patches", interactive=False)

            with gr.Row():
                original_video = gr.Video(label="Original", autoplay=False, loop=True)
                gazing_video = gr.Video(label="Gazing Pattern (all scales)", autoplay=False, loop=True)
                reconstruction_video = gr.Video(label="Reconstruction", autoplay=False, loop=True)

            with gr.Row():
                scales_stitch_video = gr.Video(label="Gazing Pattern (individual scales)", autoplay=False, loop=True)

    # Event wiring: gallery click fills the File input, which in turn refreshes
    # the metadata/state/fps fields via handle_file_upload.
    example_gallery.select(load_example_video, outputs=uploaded_file)
    uploaded_file.change(
        fn=handle_file_upload,
        inputs=[uploaded_file],
        outputs=[file_info, file_metadata, fps_slider]
    )

    # The processing handler is a generator, so each yield updates the outputs;
    # GPU memory is released afterwards regardless of the result.
    process_button.click(
        fn=process_video_ui,
        inputs=[file_metadata, gazing_ratio_slider, task_loss_slider, fps_slider],
        outputs=[
            gazing_pct,
            gazing_tokens,
            total_tokens,
            original_video,
            gazing_video,
            reconstruction_video,
            scales_stitch_video,
            status_text
        ]
    ).then(
        fn=cleanup_gpu,
        inputs=None,
        outputs=None
    )

    # Clean up GPU memory when user disconnects
    demo.unload(cleanup_gpu)
|
| 357 |
+
|
| 358 |
+
# Clear any cached models and free GPU memory at app startup
# (relevant on Spaces restarts, where stale process state may linger).
print("Clearing model cache and GPU memory at startup...")
model_cache.clear()
cleanup_gpu()
print("Startup cleanup complete.")

if __name__ == "__main__":
    # share=True exposes a public gradio.live URL; harmless on HF Spaces.
    demo.launch(share=True)
|
demo_utils.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
import av
|
| 7 |
+
import imageio
|
| 8 |
+
from transformers import VivitImageProcessor
|
| 9 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
|
| 13 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'gengaze'))
|
| 14 |
+
from autogaze.models.autogaze import AutoGaze
|
| 15 |
+
from autogaze.datasets.video_utils import read_video_pyav, transform_video_for_pytorch
|
| 16 |
+
from autogaze.tasks.video_mae_reconstruction import VideoMAEReconstruction
|
| 17 |
+
from autogaze.utils import UnNormalize
|
| 18 |
+
from tqdm import trange
|
| 19 |
+
|
| 20 |
+
# ZeroGPU support flag: the `spaces` package only exists on Spaces hardware.
try:
    import spaces
except ImportError:
    ZEROGPU_AVAILABLE = False
else:
    ZEROGPU_AVAILABLE = True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def image_to_video(image_path, output_path, fps):
    """
    Convert a single image into a one-frame video file.

    Args:
        image_path: Path to the input image.
        output_path: Path of the video file to write.
        fps: Frame rate stamped on the video.

    Returns:
        Dictionary with video metadata (width, height, frames, fps).
    """
    # Force a 3-channel RGB frame so the encoder always sees the same layout.
    frame = np.array(Image.open(image_path).convert('RGB'))

    # yuv420p keeps the resulting mp4 playable in browsers.
    writer = imageio.get_writer(output_path, fps=fps, format='FFMPEG', codec='libx264', pixelformat='yuv420p')
    with writer:
        writer.append_data(frame)

    height, width = frame.shape[:2]
    return {
        'width': width,
        'height': height,
        'frames': 1,
        'fps': fps
    }
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_model(device='cuda'):
    """Load the AutoGaze model and its VideoMAE reconstruction task head.

    Args:
        device: Torch device string both models are moved to.

    Returns:
        Dict with keys 'model' (AutoGaze), 'task' (VideoMAEReconstruction),
        'unnorm' (UnNormalize for visualization), 'scales' (model.scales),
        and 'transform' (VivitImageProcessor).
    """
    print("Loading AutoGaze model from HuggingFace...")
    model = AutoGaze.from_pretrained("bfshi/AutoGaze")
    model = model.to(device)
    model.eval()

    # Resize/crop to the largest gazing scale; the processor also supplies the
    # normalization stats reused below for un-normalizing outputs.
    transform = VivitImageProcessor.from_pretrained(
        "facebook/vit-mae-large",
        size=model.scales[-1],
        crop_size=model.scales[-1]
    )

    unnorm = UnNormalize(
        mean=transform.image_mean,
        std=transform.image_std,
        rescale_factor=transform.rescale_factor
    )

    print("Loading VideoMAE model from HuggingFace...")
    # Scales are passed as a '+'-joined string, e.g. "112+224".
    scales_str = '+'.join(map(str, model.scales))
    recon_model_config = OmegaConf.create({
        'scale_embed': True,
        'max_num_frames': 256,
        'time_embed': True,
        'causal': True,
        'loss_type': 'l1+dinov2_reg+siglip2',
        'loss_weights': '1',
        'l1_loss_config': {},
        'dinov2_reg_loss_config': {
            'model': 'facebook/dinov2-with-registers-base'
        },
        'siglip2_loss_config': {
            'model': 'google/siglip2-base-patch16-224'
        }
    })
    task = VideoMAEReconstruction(
        recon_model='facebook/vit-mae-large',
        recon_model_config=recon_model_config,
        scales=scales_str,
        recon_sample_rate=1,
        attn_mode='sdpa'
    )

    # Load fine-tuned weights from HuggingFace
    from huggingface_hub import hf_hub_download
    checkpoint_path = hf_hub_download(repo_id="bfshi/VideoMAE_AutoGaze", filename="videomae.pt")
    print(f"Loading VideoMAE checkpoint from {checkpoint_path}...")
    task_sd = torch.load(checkpoint_path, map_location='cpu')
    # Checkpoint keys carry a 'module.mae.' prefix (saved from a wrapped
    # module); strip it so they match task.mae's state dict.
    task_sd = {k.replace('module.mae.', ''): v for k, v in task_sd.items()}
    task.mae.load_state_dict(task_sd, strict=True)
    print("Loaded VideoMAE checkpoint from HuggingFace")

    task = task.to(device)
    task.eval()

    return {
        'model': model,
        'task': task,
        'unnorm': unnorm,
        'scales': model.scales,
        'transform': transform,
    }
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def process_video(video_path, setup, gazing_ratio=0.75, task_loss_requirement=0.6, progress_callback=None, spatial_batch_size=16):
    """
    Process a video file with AutoGaze using chunking for any resolution/duration.

    The video is padded and tiled into (16-frame, 224x224) chunks, run through the
    AutoGaze model in mini-batches, reconstructed chunk-by-chunk with the VideoMAE
    task head, and finally re-assembled into per-frame visualizations.

    Args:
        video_path: Path to video file.
        setup: Dictionary with model, task, unnorm, scales, transform (as returned
            by the model-setup function).
        gazing_ratio: Maximum percentage of patches to gaze per frame.
        task_loss_requirement: Reconstruction loss threshold.
        progress_callback: Optional callback(fraction, message) for progress updates.
        spatial_batch_size: Number of first-axis chunk groups processed per
            AutoGaze forward pass (halved automatically for long videos).

    Yields:
        None after each progress update (so a Gradio generator can refresh the UI),
        then a final dictionary with original frames, gazing frames, reconstruction
        frames, per-scale stitched frames, fps, and gazing statistics.
    """
    model = setup['model']
    task = setup['task']
    transform = setup['transform']
    device = next(model.parameters()).device
    # BUGFIX: `device` is a torch.device, so `device == 'cuda'` was always False
    # and the cache was never cleared here; compare the device *type* instead.
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # First pass over the container: read frame count and fps only.
    container = av.open(video_path)
    video_stream = container.streams.video[0]
    total_frames_available = video_stream.frames
    fps = float(video_stream.average_rate)
    container.close()

    # Second pass: decode every frame.
    container = av.open(video_path)
    sample_indices = list(range(total_frames_available))
    video = read_video_pyav(container=container, indices=sample_indices)  # (T, H, W, 3) numpy array
    container.close()

    # Keep video on CPU for preprocessing to save GPU memory
    video_tensor = torch.from_numpy(video).float()  # (T, H, W, 3)
    video_tensor = video_tensor / 255.0  # Normalize to [0, 1]
    video_tensor = video_tensor.permute(0, 3, 1, 2)  # (T, C, H, W)
    T, C, H, W = video_tensor.shape
    if T > 200:
        # BUGFIX: the old message claimed the batch size was set "to 2", but the
        # code halves it (e.g. 16 -> 8). Message now matches the action.
        print(f'Video has {T} frames, which may require significant GPU memory. Halving spatial_batch_size.')
        spatial_batch_size //= 2

    # Clone for later visualization (keep on CPU)
    video_tensor_original = video_tensor.clone()

    # Pad video so T is divisible by 16 and H/W are divisible by 224.
    pad_t = (16 - T % 16) % 16
    pad_h = (224 - H % 224) % 224
    pad_w = (224 - W % 224) % 224

    if pad_t > 0 or pad_h > 0 or pad_w > 0:
        # F.pad pads last dims first: (W_left, W_right, H_top, H_bottom, C, C, T_front, T_back)
        video_tensor = F.pad(video_tensor, (0, pad_w, 0, pad_h, 0, 0, 0, pad_t))

    # Chunk video into 16-frame, 224x224 chunks (following QUICK_START.md)
    video_tensor = video_tensor.unsqueeze(0)  # 1 * T * C * H * W

    # Calculate chunking dimensions
    nt = (T + pad_t) // 16
    nh = (H + pad_h) // 224
    nw = (W + pad_w) // 224
    num_spatial_chunks = nh * nw
    num_chunks = nt * num_spatial_chunks

    # Chunk into (num_chunks, 16, C, 224, 224); flat chunk order is (nt, nh, nw)-major.
    video_chunks = rearrange(video_tensor, 'B (nt t) C (nh h) (nw w) -> (B nt nh nw) t C h w', t=16, h=224, w=224)

    print(f"Video chunked into {num_chunks} chunks ({nt} temporal x {num_spatial_chunks} spatial) of shape (16, {C}, 224, 224). Original shape: ({T}, {C}, {H}, {W})")

    # Apply VivitImageProcessor normalization to chunks
    # Rearrange chunks to process all frames: (num_chunks, 16, C, H, W) -> (num_chunks * 16, C, H, W)
    chunks_flat = rearrange(video_chunks, 'b t c h w -> (b t) c h w')

    # Apply normalization using VivitImageProcessor's mean and std (on CPU)
    mean = torch.tensor(transform.image_mean).view(1, 3, 1, 1)
    std = torch.tensor(transform.image_std).view(1, 3, 1, 1)
    chunks_flat = (chunks_flat - mean) / std

    video_chunks = rearrange(chunks_flat, '(b t) c h w -> b t c h w', b=num_chunks, t=16)
    # NOTE(review): the flat chunk axis is (nt, nh, nw)-major, so splitting it as
    # '(ns nt)' does not actually group by spatial location as the labels suggest.
    # The later '(ns nt) -> ns nt ... -> (ns nt)' round-trip is an identity, so the
    # final assembly stays consistent; this only affects how chunks are batched.
    video_chunks = rearrange(video_chunks, '(ns nt) t c h w -> ns nt t c h w', ns=num_spatial_chunks, nt=nt)

    # Keep video_chunks on CPU - only move mini-batches to GPU as needed
    print(f'video_chunks shape (spatial, temporal, frames, C, H, W): {video_chunks.shape}')

    # Free the padded full-resolution tensor before heavy GPU work.
    del video_tensor, chunks_flat, mean, std

    with torch.inference_mode():
        # Process first-axis chunk groups in mini-batches.
        num_spatial_batches = (num_spatial_chunks + spatial_batch_size - 1) // spatial_batch_size

        all_gaze_outputs = []
        total_gazing_tokens = 0

        for batch_idx in range(num_spatial_batches):
            start_idx = batch_idx * spatial_batch_size
            end_idx = min(start_idx + spatial_batch_size, num_spatial_chunks)
            batch_size = end_idx - start_idx

            gazing_pct = int(((batch_idx + 1) / num_spatial_batches) * 100)
            if progress_callback:
                progress_callback(0.1 + 0.4 * (batch_idx / num_spatial_batches), f"Gazing progress: {gazing_pct}%")
                yield None

            # Extract mini-batch from CPU and move to GPU: (batch_size, nt, 16, C, H, W)
            spatial_batch = video_chunks[start_idx:end_idx].to(device)
            # Flatten to (batch_size * nt, 16, C, H, W) for model
            spatial_batch = rearrange(spatial_batch, 'bs nt t c h w -> (bs nt) t c h w')
            print(f'Processing spatial batch {batch_idx+1}/{num_spatial_batches} with {batch_size} spatial locations x {nt} temporal = {spatial_batch.shape[0]} chunks')

            # Run AutoGaze on this mini-batch
            batch_gaze_output = model({"video": spatial_batch}, gazing_ratio=gazing_ratio, task_loss_requirement=task_loss_requirement)

            # Free GPU memory after forward pass
            del spatial_batch

            # Count gazing tokens for this batch; padded positions are excluded.
            if_padded = batch_gaze_output.get('if_padded_gazing')
            if if_padded is not None:
                total_gazing_tokens += (~if_padded).sum().item()
            else:
                # Fallback: valid positions are < 196 patches * 16 frames.
                total_gazing_tokens += (batch_gaze_output['gazing_pos'] < (196 * 16)).sum().item()

            # Store the output
            all_gaze_outputs.append(batch_gaze_output)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print("Merging mini-batch results...")

        # Find max sequence length across all mini-batches
        max_seq_len = max(out['gazing_pos'].shape[1] for out in all_gaze_outputs)

        # Pad gazing_pos and if_padded_gazing to same length (they have variable seq length).
        # gazing_mask doesn't need padding since all chunks have same shape.
        padded_gazing_pos = []
        padded_if_padded_gazing = []

        for out in all_gaze_outputs:
            seq_len = out['gazing_pos'].shape[1]
            pad_len = max_seq_len - seq_len

            # Pad gazing_pos with zeros
            padded_pos = F.pad(out['gazing_pos'], (0, pad_len), value=0)
            padded_gazing_pos.append(padded_pos)

            # Pad if_padded_gazing and mark new positions as True (padded)
            if 'if_padded_gazing' in out:
                padded_if_pad = F.pad(out['if_padded_gazing'], (0, pad_len), value=True)
                padded_if_padded_gazing.append(padded_if_pad)

        # Store num_gazing_each_frame per mini-batch for later per-chunk extraction
        num_gazing_each_frame_list = [out['num_gazing_each_frame'] for out in all_gaze_outputs]
        batch_sizes = [out['gazing_pos'].shape[0] for out in all_gaze_outputs]

        gaze_output = {
            'gazing_pos': torch.cat(padded_gazing_pos, dim=0),
            'gazing_mask': [torch.cat([out['gazing_mask'][i] for out in all_gaze_outputs], dim=0) for i in range(4)],
            'num_gazing_each_frame_list': num_gazing_each_frame_list,  # List of values per mini-batch
            'batch_sizes': batch_sizes,  # Track which chunks came from which mini-batch
            'frame_sampling_rate': all_gaze_outputs[0]['frame_sampling_rate'],
            'num_vision_tokens_each_frame': all_gaze_outputs[0]['num_vision_tokens_each_frame'],
        }
        if len(padded_if_padded_gazing) > 0:
            gaze_output['if_padded_gazing'] = torch.cat(padded_if_padded_gazing, dim=0)

        # Clean up mini-batch outputs
        del all_gaze_outputs

        total_possible_tokens = 196 * 16 * num_chunks

        # Extract gazing masks for later visualization (already in batched form)
        gazing_masks_batched = gaze_output['gazing_mask']  # List of 4 scales, each (num_chunks, 16, num_patches)

        # Flatten video_chunks back to (num_chunks, 16, C, H, W) for reconstruction
        video_chunks_flat = rearrange(video_chunks, 'ns nt t c h w -> (ns nt) t c h w').cpu()

        # Pre-allocate reconstruction tensor on CPU to avoid memory accumulation
        total_frames = num_chunks * 16
        C = video_chunks_flat.shape[2]
        reconstruction_chunks = torch.zeros((total_frames, C, 224, 224), dtype=torch.float32)
        frame_idx_counter = 0

        # Process reconstruction in mini-batches matching AutoGaze batch structure
        num_autogaze_batches = len(gaze_output['num_gazing_each_frame_list'])
        print(f'Reconstructing {num_chunks} chunks in {num_autogaze_batches} batches (aligned with AutoGaze batches)...')

        chunk_idx = 0
        for autogaze_batch_idx in range(num_autogaze_batches):
            batch_size = gaze_output['batch_sizes'][autogaze_batch_idx]
            start_chunk_idx = chunk_idx
            end_chunk_idx = chunk_idx + batch_size

            print(f'Reconstructing chunks {start_chunk_idx+1}-{end_chunk_idx}/{num_chunks}...')

            # Extract videos for all chunks in this AutoGaze batch
            batch_videos = video_chunks_flat[start_chunk_idx:end_chunk_idx].to(device)  # (batch_size, 16, C, H, W)

            # Extract gazing data for all chunks in this AutoGaze batch
            batch_gazing_pos = gaze_output['gazing_pos'][start_chunk_idx:end_chunk_idx]
            batch_gazing_mask = [scale_mask[start_chunk_idx:end_chunk_idx] for scale_mask in gaze_output['gazing_mask']]
            batch_num_gazing_each_frame = gaze_output['num_gazing_each_frame_list'][autogaze_batch_idx]

            # Trim to expected sequence length for this AutoGaze batch (undo the
            # cross-batch padding applied above).
            expected_seq_len = batch_num_gazing_each_frame.sum().item()
            batch_gazing_pos = batch_gazing_pos[:, :expected_seq_len]

            chunk_idx = end_chunk_idx

            batch_gaze_output = {
                'gazing_pos': batch_gazing_pos,
                'gazing_mask': batch_gazing_mask,
                'num_gazing_each_frame': batch_num_gazing_each_frame,
                'frame_sampling_rate': gaze_output['frame_sampling_rate'],
                'num_vision_tokens_each_frame': gaze_output['num_vision_tokens_each_frame'],
            }

            if 'if_padded_gazing' in gaze_output:
                batch_if_padded = gaze_output['if_padded_gazing'][start_chunk_idx:end_chunk_idx]
                batch_if_padded = batch_if_padded[:, :expected_seq_len]
                batch_gaze_output['if_padded_gazing'] = batch_if_padded

            # Reconstruct frame by frame for this batch
            batch_video_dict = {"video": batch_videos}
            # Pre-allocate batch_reconstructions tensor to avoid list + stack memory spike
            batch_reconstructions = torch.zeros((16, batch_size, C, 224, 224), device=device)
            for frame_idx in range(16):
                # Update progress for each frame
                frame_pct = int(((autogaze_batch_idx * 16 + frame_idx + 1) / (num_autogaze_batches * 16)) * 100)
                if progress_callback:
                    progress_callback(0.5 + 0.4 * ((autogaze_batch_idx * 16 + frame_idx + 1) / (num_autogaze_batches * 16)), f"Reconstruction progress: {frame_pct}%")
                    yield None

                task_output = task.forward_output(batch_video_dict, batch_gaze_output, frame_idx_to_reconstruct=[frame_idx])
                batch_reconstructions[frame_idx] = task_output['reconstruction'][:, 0]  # (recon_batch_size, C, H, W)
                del task_output

            # Reorder from (16, recon_batch_size, C, H, W) to frame-major per chunk:
            # (recon_batch_size * 16, C, H, W) to match expected chunk ordering.
            batch_reconstructions = rearrange(batch_reconstructions, 't b c h w -> (b t) c h w')

            # Write directly into pre-allocated tensor
            batch_size_frames = batch_reconstructions.shape[0]
            reconstruction_chunks[frame_idx_counter:frame_idx_counter+batch_size_frames] = batch_reconstructions.cpu()
            frame_idx_counter += batch_size_frames

            # Clean up batch-specific variables
            del batch_videos, batch_gaze_output, batch_video_dict, batch_reconstructions
        print('Reconstruction complete.')
        # Manually reverse the mean/std normalization to get back to [0, 1] range
        mean = torch.tensor(transform.image_mean).view(1, 3, 1, 1).to(reconstruction_chunks.device)
        std = torch.tensor(transform.image_std).view(1, 3, 1, 1).to(reconstruction_chunks.device)
        reconstruction_chunks = reconstruction_chunks * std + mean

        # Clean up video chunks and gaze output to free GPU memory (keep gazing_masks_batched for later)
        del video_chunks, video_chunks_flat, gaze_output

        # Reshape chunks back to original structure (nt, nh, nw already calculated earlier)
        print(f'Reshaping reconstructed chunks back to video tensor...')
        reconstruction_tensor = rearrange(reconstruction_chunks, '(nt nh nw t) C h w -> (nt t) C (nh h) (nw w)', nt=nt, nh=nh, nw=nw, t=16)
        reconstruction_tensor = reconstruction_tensor[:T, :, :H, :W]  # Remove padding

        # Move reconstruction to GPU for visualization
        reconstruction_tensor = reconstruction_tensor.to(device)

        # Undo chunking on the per-scale gazing masks as well.
        gazing_mask_assembled = []
        for scale_idx in range(4):
            scale_masks_stacked = gazing_masks_batched[scale_idx]

            # Reshape: (num_chunks, 16, num_patches) -> (num_chunks * 16, num_patches)
            scale_masks_flat = scale_masks_stacked.reshape(-1, scale_masks_stacked.shape[-1])

            # Rearrange back to original video structure
            scale_masks_reshaped = rearrange(scale_masks_flat, '(nt nh nw t) n -> (nt t) (nh nw) n', nt=nt, nh=nh, nw=nw, t=16)
            scale_masks_reshaped = scale_masks_reshaped[:T]  # Remove temporal padding

            gazing_mask_assembled.append(scale_masks_reshaped)

            del scale_masks_stacked, scale_masks_flat, scale_masks_reshaped

        del gazing_masks_batched

        pct = total_gazing_tokens / total_possible_tokens

        # Move original video to GPU for visualization
        video_viz = video_tensor_original.to(device)

        # Generate frame-by-frame visualizations
        original_frames = []
        composite_frames = []
        reconstruction_frames = []
        scales_stitch_frames = []

        print('Visualizing...')
        if progress_callback:
            progress_callback(0.9, "Visualizing...")
            yield None
        for t in trange(T):
            # Original frame
            frame = video_viz[t].permute(1, 2, 0)
            frame = torch.clip(frame, 0, 1)
            frame_uint8 = (frame * 255).byte().cpu().numpy()
            original_frames.append(frame_uint8)

            # Reconstruction frame
            recon_frame = reconstruction_tensor[t].permute(1, 2, 0)
            recon_frame = torch.clip(recon_frame, 0, 1)
            recon_uint8 = (recon_frame * 255).byte().cpu().numpy()
            reconstruction_frames.append(recon_uint8)

            # Blended multi-scale gazing overlay, accumulated scale by scale.
            composite = torch.zeros((H, W, 3)).to(device)
            scales = setup['scales']
            alpha_values = [0.4, 0.5, 0.6, 0.7]  # Per-scale opacity (coarse to fine)
            # NOTE(review): `colors` is defined but never used below — kept for
            # behavioral parity; consider wiring it into the overlay or removing.
            colors = [
                [1.0, 0.0, 0.0],  # Scale 0 (coarsest): Red
                [0.0, 1.0, 0.0],  # Scale 1: Green
                [0.0, 0.0, 1.0],  # Scale 2: Blue
                [1.0, 1.0, 0.0]   # Scale 3 (finest): Yellow
            ]

            for scale_idx in range(4):
                scale = scales[scale_idx]
                scale_h = int(scale * H / 224)
                scale_w = int(scale * W / 224)

                # Get mask for this scale and frame
                mask = gazing_mask_assembled[scale_idx][t]  # (nh * nw, num_patches)

                # Reshape mask: (nh * nw, num_patches) where num_patches = s^2
                num_patches_per_chunk = mask.shape[-1]
                s = int(num_patches_per_chunk ** 0.5)

                # Rearrange to 2D spatial grid
                mask_2d = rearrange(mask, '(nh nw) (h w) -> (nh h) (nw w)', nh=nh, nw=nw, h=s, w=s)

                # Convert to tensor if needed
                if isinstance(mask_2d, np.ndarray):
                    mask_tensor = torch.from_numpy(mask_2d)
                else:
                    mask_tensor = mask_2d

                mask_resized = F.interpolate(mask_tensor.unsqueeze(0).unsqueeze(0).float(), size=(scale_h, scale_w), mode='nearest')[0, 0]

                frame_tensor = video_viz[t]
                frame_scaled = F.interpolate(frame_tensor.unsqueeze(0), size=(scale_h, scale_w), mode='bicubic', align_corners=False).squeeze().clamp(0, 1)

                frame_scaled_masked = frame_scaled * mask_resized.unsqueeze(0)

                # Upsample both masked frame and mask to full size
                frame_upsampled = F.interpolate(frame_scaled_masked.unsqueeze(0), size=(H, W), mode='nearest').squeeze()
                mask_upsampled = F.interpolate(mask_resized.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest').squeeze()

                frame_upsampled = frame_upsampled.permute(1, 2, 0)

                # Alpha-blend this scale's gazed regions into the composite.
                composite = composite * (1 - mask_upsampled[:, :, None] * alpha_values[scale_idx]) + frame_upsampled * alpha_values[scale_idx]

            composite_np = composite.detach().cpu().numpy()
            composite_np = (composite_np - composite_np.min()) / (composite_np.max() - composite_np.min() + 1e-8)
            composite_uint8 = (composite_np * 255).astype(np.uint8)
            composite_frames.append(composite_uint8)

            # Create individual scale visualizations for horizontal stitch
            scale_composites = []
            label_bar_height = 30

            for scale_idx in range(4):
                scale = scales[scale_idx]
                scale_h = int(scale * H / 224)
                scale_w = int(scale * W / 224)

                # Get mask for this scale and frame
                mask = gazing_mask_assembled[scale_idx][t]

                # Reshape mask to 2D spatial grid
                num_patches_per_chunk = mask.shape[-1]
                s = int(num_patches_per_chunk ** 0.5)
                mask_2d = rearrange(mask, '(nh nw) (h w) -> (nh h) (nw w)', nh=nh, nw=nw, h=s, w=s)

                if isinstance(mask_2d, np.ndarray):
                    mask_tensor_scale = torch.from_numpy(mask_2d)
                else:
                    mask_tensor_scale = mask_2d

                mask_resized_scale = F.interpolate(mask_tensor_scale.unsqueeze(0).unsqueeze(0).float(), size=(scale_h, scale_w), mode='nearest')[0, 0]

                frame_tensor_scale = video_viz[t]
                frame_scaled_scale = F.interpolate(frame_tensor_scale.unsqueeze(0), size=(scale_h, scale_w), mode='bicubic', align_corners=False).squeeze().clamp(0, 1)

                # Apply gazing pattern: gazed tiles = 1.0 brightness, ungazed tiles = 0.2 brightness
                frame_scaled_permuted = frame_scaled_scale.permute(1, 2, 0)
                scale_composite = frame_scaled_permuted * (mask_resized_scale[:, :, None] * 1.0 + (1 - mask_resized_scale[:, :, None]) * 0.2)

                scale_composite_np = scale_composite.detach().cpu().numpy()
                scale_composite_np = np.clip(scale_composite_np, 0, 1)
                scale_composite_uint8 = (scale_composite_np * 255).astype(np.uint8)

                # Resize visualization to common display height first (preserving aspect ratio)
                display_width = int(scale_w * H / scale_h)
                scale_composite_pil = Image.fromarray(scale_composite_uint8)
                scale_composite_resized = scale_composite_pil.resize((display_width, H), Image.NEAREST)
                scale_composite_resized_np = np.array(scale_composite_resized)

                # Create label bar matching the resized visualization width
                label_bar = np.ones((label_bar_height, display_width, 3), dtype=np.uint8) * 255
                label_bar_pil = Image.fromarray(label_bar)
                draw = ImageDraw.Draw(label_bar_pil)
                try:
                    font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
                except (OSError, IOError):
                    # Font file unavailable on this host; fall back to PIL's builtin.
                    font = ImageFont.load_default()

                label = f"Scale {scale_idx + 1}"
                draw.text((5, 5), label, fill=(0, 0, 0), font=font)
                label_bar_np = np.array(label_bar_pil)

                # Stack label bar above the visualization
                scale_with_label = np.vstack([label_bar_np, scale_composite_resized_np])

                scale_composites.append(scale_with_label)

            # Add 10px white padding between scales
            padding = np.ones((H + label_bar_height, 10, 3), dtype=np.uint8) * 255

            # Concatenate all scales horizontally with padding
            stitched = scale_composites[0]
            for i in range(1, 4):
                stitched = np.concatenate([stitched, padding, scale_composites[i]], axis=1)

            # Add white padding at the top to prevent Gradio's label from blocking content
            top_padding = np.ones((50, stitched.shape[1], 3), dtype=np.uint8) * 255
            stitched = np.vstack([top_padding, stitched])

            scales_stitch_frames.append(stitched)

            del frame_tensor, mask_tensor, mask_resized, frame_scaled, frame_scaled_masked, frame_upsampled, mask_upsampled

        del gazing_mask_assembled

        del video_tensor_original, reconstruction_tensor, video_viz, reconstruction_chunks

        # BUGFIX: same torch.device-vs-string comparison as at the top of the
        # function; this empty_cache() was previously unreachable on CUDA hosts.
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        yield {
            'original_frames': original_frames,
            'gazing_frames': composite_frames,
            'reconstruction_frames': reconstruction_frames,
            'scales_stitch_frames': scales_stitch_frames,
            'fps': fps,
            'gazing_pct': pct,
            'total_gazing_tokens': total_gazing_tokens,
            'total_possible_tokens': total_possible_tokens
        }
|
| 576 |
+
def save_video(frames, output_path, fps):
    """Encode an iterable of RGB frames as an H.264 MP4 at the given frame rate."""
    writer = imageio.get_writer(output_path, fps=fps, format='FFMPEG', codec='libx264', pixelformat='yuv420p')
    with writer:
        for image in frames:
            writer.append_data(image)
environment.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: gengaze_demo
|
| 2 |
+
channels:
|
| 3 |
+
- nvidia
|
| 4 |
+
- conda-forge
|
| 5 |
+
- defaults
|
| 6 |
+
dependencies:
|
| 7 |
+
- python=3.10
|
| 8 |
+
- pip
|
| 9 |
+
- nvidia::cuda-toolkit=12.6
|
| 10 |
+
- pip:
|
| 11 |
+
- torch==2.7.1
|
| 12 |
+
- torchvision==0.22.1
|
| 13 |
+
- torchaudio==2.7.1
|
| 14 |
+
- numpy==1.26.4
|
| 15 |
+
- pillow==10.4.0
|
| 16 |
+
- matplotlib==3.10.1
|
| 17 |
+
- gradio>=4.0.0
|
| 18 |
+
- spaces
|
| 19 |
+
- flash_attn==2.8.0.post2
|
| 20 |
+
- hydra-core==1.3.2
|
| 21 |
+
- wandb==0.21.0
|
| 22 |
+
- loguru==0.7.3
|
| 23 |
+
- timm==1.0.15
|
| 24 |
+
- tqdm==4.67.1
|
| 25 |
+
- transformers==4.53.0
|
| 26 |
+
- omegaconf==2.3.0
|
| 27 |
+
- einops==0.8.1
|
| 28 |
+
- av==14.4.0
|
| 29 |
+
- imageio==2.37.0
|
example_inputs/aerial.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e90d807c5d0438ff80112a2634b8cc10c4700cfbeaede96c5bd931035f170f46
|
| 3 |
+
size 298170
|
example_inputs/aerial_thumb.png
ADDED
|
|
example_inputs/doorbell.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f8667f28dd39c89b630e4ad29cf49e8bc82fb5ed26196fe0371b0f10e54a2ba9
|
| 3 |
+
size 460064
|
example_inputs/doorbell_thumb.png
ADDED
|
|
example_inputs/tomjerry.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e0a7d90ea96f817268e44dc38f58bdac3348df7f64a3eb54bba79ed5e7df7a3
|
| 3 |
+
size 435371
|
example_inputs/tomjerry_thumb.png
ADDED
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==1.26.4
|
| 2 |
+
pillow==10.4.0
|
| 3 |
+
matplotlib==3.10.1
|
| 4 |
+
gradio>=4.0.0
|
| 5 |
+
spaces
|
| 6 |
+
hydra-core==1.3.2
|
| 7 |
+
wandb==0.21.0
|
| 8 |
+
loguru==0.7.3
|
| 9 |
+
timm==1.0.15
|
| 10 |
+
tqdm==4.67.1
|
| 11 |
+
transformers==4.53.0
|
| 12 |
+
omegaconf==2.3.0
|
| 13 |
+
einops==0.8.1
|
| 14 |
+
av==14.4.0
|
| 15 |
+
imageio==2.37.0
|