Anthony Liang committed on
Commit
8bcb31b
·
1 Parent(s): 7097e14
Files changed (2) hide show
  1. app.py +30 -10
  2. eval_viz_utils.py +229 -2
app.py CHANGED
@@ -27,7 +27,11 @@ from typing import Any, List, Optional, Tuple
27
 
28
  from dataset_types import Trajectory, ProgressSample, PreferenceSample
29
  from eval_utils import build_payload, post_batch_npy
30
- from eval_viz_utils import create_combined_progress_success_plot, extract_frames
 
 
 
 
31
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
32
 
33
  logger = logging.getLogger(__name__)
@@ -421,22 +425,24 @@ def process_single_video(
421
  server_url: str = "",
422
  fps: float = 1.0,
423
  use_frame_steps: bool = False,
424
- ) -> Tuple[Optional[str], Optional[str]]:
425
- """Process single video for progress and success predictions using eval server."""
 
 
426
  # Get server URL from state if not provided
427
  if not server_url:
428
  server_url = _server_state.get("server_url")
429
 
430
  if not server_url:
431
- return None, "Please select a model from the dropdown above and ensure it's connected."
432
 
433
  if video_path is None:
434
- return None, "Please provide a video."
435
 
436
  try:
437
  frames_array = extract_frames(video_path, fps=fps)
438
  if frames_array is None or frames_array.size == 0:
439
- return None, "Could not extract frames from video."
440
 
441
  # Convert frames to (T, H, W, C) numpy array with uint8 values
442
  if frames_array.dtype != np.uint8:
@@ -520,11 +526,24 @@ def process_single_video(
520
  if success_array is not None and len(success_array) > 0:
521
  info_text += f"**Final success probability:** {success_array[-1]:.3f}\n"
522
 
523
- # Return combined plot (which includes success if available)
524
- return progress_plot, info_text
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
  except Exception as e:
527
- return None, f"Error processing video: {str(e)}"
528
 
529
 
530
  def process_two_videos(
@@ -781,6 +800,7 @@ with demo:
781
 
782
  with gr.Column():
783
  progress_plot = gr.Image(label="Progress & Success Prediction", height=320)
 
784
  info_output = gr.Markdown("")
785
  gr.Markdown("---")
786
  gr.Markdown("**Examples**")
@@ -1004,7 +1024,7 @@ with demo:
1004
  fps_input_single,
1005
  use_frame_steps_single,
1006
  ],
1007
- outputs=[progress_plot, info_output],
1008
  api_name="process_single_video",
1009
  )
1010
 
 
27
 
28
  from dataset_types import Trajectory, ProgressSample, PreferenceSample
29
  from eval_utils import build_payload, post_batch_npy
30
+ from eval_viz_utils import (
31
+ create_combined_progress_success_plot,
32
+ create_progress_success_gif,
33
+ extract_frames,
34
+ )
35
  from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
36
 
37
  logger = logging.getLogger(__name__)
 
425
  server_url: str = "",
426
  fps: float = 1.0,
427
  use_frame_steps: bool = False,
428
+ ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
429
+ """Process single video for progress and success predictions using eval server.
430
+ Returns (static_plot_path, video_path, info_text). video_path is the 5 sec MP4 animation; may be None if creation fails.
431
+ """
432
  # Get server URL from state if not provided
433
  if not server_url:
434
  server_url = _server_state.get("server_url")
435
 
436
  if not server_url:
437
+ return None, None, "Please select a model from the dropdown above and ensure it's connected."
438
 
439
  if video_path is None:
440
+ return None, None, "Please provide a video."
441
 
442
  try:
443
  frames_array = extract_frames(video_path, fps=fps)
444
  if frames_array is None or frames_array.size == 0:
445
+ return None, None, "Could not extract frames from video."
446
 
447
  # Convert frames to (T, H, W, C) numpy array with uint8 values
448
  if frames_array.dtype != np.uint8:
 
526
  if success_array is not None and len(success_array) > 0:
527
  info_text += f"**Final success probability:** {success_array[-1]:.3f}\n"
528
 
529
+ # Animated MP4: progress + success curves (5 sec clip) with optional video panel
530
+ video_path = None
531
+ if len(progress_array) > 0:
532
+ mp4_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
533
+ mp4_file.close()
534
+ video_path = create_progress_success_gif(
535
+ progress_pred=progress_array,
536
+ success_data=success_binary if success_binary is not None else success_array,
537
+ video_frames=frames_array,
538
+ output_path=mp4_file.name,
539
+ title=task_text,
540
+ duration_sec=5.0,
541
+ )
542
+
543
+ return progress_plot, video_path, info_text
544
 
545
  except Exception as e:
546
+ return None, None, f"Error processing video: {str(e)}"
547
 
548
 
549
  def process_two_videos(
 
800
 
801
  with gr.Column():
802
  progress_plot = gr.Image(label="Progress & Success Prediction", height=320)
803
+ progress_video = gr.Video(label="Animated Progress & Success (5 sec MP4)", height=320)
804
  info_output = gr.Markdown("")
805
  gr.Markdown("---")
806
  gr.Markdown("**Examples**")
 
1024
  fps_input_single,
1025
  use_frame_steps_single,
1026
  ],
1027
+ outputs=[progress_plot, progress_video, info_output],
1028
  api_name="process_single_video",
1029
  )
1030
 
eval_viz_utils.py CHANGED
@@ -8,11 +8,24 @@ import os
8
  import logging
9
  import tempfile
10
  import numpy as np
 
 
11
  import matplotlib.pyplot as plt
 
12
  import decord
13
 
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def create_combined_progress_success_plot(
18
  progress_pred: np.ndarray,
@@ -49,13 +62,13 @@ def create_combined_progress_success_plot(
49
 
50
  if has_success_binary:
51
  # Three subplots: progress, success (binary), success_probs
52
- fig, axs = plt.subplots(1, 3, figsize=(15, 3.5))
53
  ax = axs[0] # Progress subplot
54
  ax2 = axs[1] # Success subplot (binary)
55
  ax3 = axs[2] # Success probs subplot
56
  else:
57
  # Single subplot: progress only
58
- fig, ax = plt.subplots(figsize=(6, 3.5))
59
  ax2 = None
60
  ax3 = None
61
 
@@ -203,3 +216,217 @@ def extract_frames(video_path: str, fps: float = 1.0, max_frames: int = 64) -> n
203
  except Exception as e:
204
  logger.error(f"Error extracting frames from {video_path}: {e}")
205
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import logging
9
  import tempfile
10
  import numpy as np
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
  import matplotlib.pyplot as plt
14
+ import matplotlib.ticker as ticker
15
  import decord
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Colors and layout for progress/success animation (Robometer red)
20
+ PROGRESS_COLOR = "#B20000"
21
+ SUCCESS_COLOR = "#B20000"
22
+ THEME_LIGHT = {"facecolor": "white", "text_color": "black", "spine_color": "#333333"}
23
+
24
+ # Serif font (Palatino) for plots
25
+ plt.rcParams["font.family"] = "serif"
26
+ plt.rcParams["font.serif"] = ["Palatino", "Palatino Linotype", "DejaVu Serif", "serif"]
27
+ plt.rcParams["font.size"] = 11
28
+
29
 
30
  def create_combined_progress_success_plot(
31
  progress_pred: np.ndarray,
 
62
 
63
  if has_success_binary:
64
  # Three subplots: progress, success (binary), success_probs
65
+ fig, axs = plt.subplots(1, 3, figsize=(18, 3.5))
66
  ax = axs[0] # Progress subplot
67
  ax2 = axs[1] # Success subplot (binary)
68
  ax3 = axs[2] # Success probs subplot
69
  else:
70
  # Single subplot: progress only
71
+ fig, ax = plt.subplots(figsize=(7, 3.5))
72
  ax2 = None
73
  ax3 = None
74
 
 
216
  except Exception as e:
217
  logger.error(f"Error extracting frames from {video_path}: {e}")
218
  return None
219
+
220
+
221
+ def resize_frames_keep_aspect(
222
+ frames: np.ndarray,
223
+ max_edge: int = 480,
224
+ ) -> np.ndarray:
225
+ """Resize video frames so the longer edge is at most max_edge, preserving aspect ratio.
226
+ Use when creating videos so the image is not stretched. Uses scipy if available.
227
+ """
228
+ if frames is None or frames.size == 0 or frames.ndim != 4:
229
+ return frames
230
+ t, h, w, c = frames.shape
231
+ if h <= 0 or w <= 0:
232
+ return frames
233
+ scale = min(max_edge / max(h, w), 1.0)
234
+ if scale >= 1.0:
235
+ return frames
236
+ new_h = max(1, round(h * scale))
237
+ new_w = max(1, round(w * scale))
238
+ try:
239
+ from scipy.ndimage import zoom
240
+ zoom_factors = (1.0, new_h / h, new_w / w, 1.0)
241
+ out = zoom(frames.astype(np.float64), zoom_factors, order=1)
242
+ return np.clip(out, 0, 255).astype(np.uint8)
243
+ except ImportError:
244
+ return frames
245
+
246
+
247
def _style_progress_ax(ax, theme: dict, ylabel: str = "Progress"):
    """Apply the shared look used by the progress/success curve panels.

    Sets the themed face color, a fixed [0, 1] y-range with sparse ticks,
    a bold y-label, integer x-ticks (frame indices), and hides the top and
    right spines so the panel reads as an open curve plot.
    """
    text_color = theme["text_color"]
    ax.set_facecolor(theme["facecolor"])
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlabel("")
    ax.set_ylabel(ylabel, fontsize=12, fontweight="bold", color=text_color)
    # Keep only the left/bottom spines, tinted to the theme.
    for side in ("left", "bottom"):
        ax.spines[side].set_color(theme["spine_color"])
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
    # Integer x ticks (frame indices); minimal 0 / 0.5 / 1 y scale.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True, nbins=8))
    ax.set_yticks([0, 0.5, 1.0])
    ax.tick_params(axis="both", labelsize=10, colors=text_color)
260
+
261
+
262
def create_progress_success_gif(
    progress_pred: np.ndarray,
    success_data: Optional[np.ndarray] = None,
    video_frames: Optional[np.ndarray] = None,
    output_path: Optional[str] = None,
    title: Optional[str] = None,
    duration_sec: float = 5.0,
    theme: Optional[dict] = None,
) -> Optional[str]:
    """Create an animated MP4 of progress/success curves growing frame-by-frame.

    An optional video panel is drawn on the left, synchronized with the curves.
    Uses the light theme by default for the web UI. The clip always spans
    ``duration_sec`` seconds; fps is derived as num_frames / duration_sec.

    Args:
        progress_pred: 1-D per-frame progress values (plotted on a [0, 1] axis).
        success_data: optional per-frame success values; edge-padded to
            len(progress_pred) when shorter.
        video_frames: optional (T, H, W, C) frames shown beside the curves;
            trimmed or edge-padded to len(progress_pred).
        output_path: destination path; a temp file is created when omitted, and
            any other extension is normalized to ``.mp4``.
        title: optional figure title (task text).
        duration_sec: total animation length in seconds.
        theme: color dict with facecolor/text_color/spine_color keys;
            defaults to THEME_LIGHT.

    Returns:
        The saved MP4 path, or None when there is no data or saving fails
        (e.g. ffmpeg missing).
    """
    from matplotlib.animation import FuncAnimation

    theme = theme or THEME_LIGHT
    progress_pred = np.atleast_1d(progress_pred).astype(float)
    num_frames = len(progress_pred)
    if num_frames == 0:
        return None

    # FPS so the full animation runs for duration_sec (e.g. 5 seconds).
    fps = max(1, round(num_frames / duration_sec))

    # Edge-pad success data so both curves share the same x-axis length.
    success_padded = None
    if success_data is not None and np.size(success_data) > 0:
        s = np.atleast_1d(success_data).astype(float)
        if len(s) < num_frames:
            s = np.pad(s, (0, num_frames - len(s)), mode="edge")
        success_padded = s

    # BUGFIX: accept any non-empty video. The previous check required
    # shape[0] >= num_frames, which silently dropped shorter videos and made
    # the edge-padding branch below unreachable dead code.
    has_video = (
        video_frames is not None
        and getattr(video_frames, "shape", (0,))[0] > 0
    )
    if has_video and video_frames.shape[0] > num_frames:
        video_frames = video_frames[:num_frames]
    elif has_video and video_frames.shape[0] < num_frames:
        # Repeat the last frame so the video panel lasts the whole clip.
        pad = np.repeat(video_frames[-1:], num_frames - video_frames.shape[0], axis=0)
        video_frames = np.concatenate([video_frames, pad], axis=0)
    if has_video:
        video_frames = resize_frames_keep_aspect(video_frames, max_edge=480)

    n_panels = 2 if success_padded is not None else 1
    width_per_panel = 5.5
    figsize = (width_per_panel * n_panels, 3.2) if not has_video else (2 + width_per_panel * n_panels, 3.2)

    if has_video:
        from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

        fig = plt.figure(facecolor=theme["facecolor"], figsize=figsize)
        # Smaller video column + larger wspace so the video does not crowd
        # the Progress panel.
        gs = GridSpec(1, 2, figure=fig, width_ratios=[0.85, n_panels], wspace=0.4)
        ax_video = fig.add_subplot(gs[0])
        ax_video.set_facecolor(theme["facecolor"])
        ax_video.axis("off")
        first = video_frames[0]
        # aspect="equal" preserves the frame's aspect ratio (no flattening).
        vid_im = ax_video.imshow(
            np.clip(first, 0, 255).astype(np.uint8) if first.ndim >= 3 else first,
            cmap="gray" if first.ndim == 2 else None,
            aspect="equal",
        )
        gs_right = GridSpecFromSubplotSpec(1, n_panels, subplot_spec=gs[1], wspace=0.3)
        axes = [fig.add_subplot(gs_right[0, j]) for j in range(n_panels)]
    else:
        fig, axes = plt.subplots(
            1, n_panels, figsize=figsize, facecolor=theme["facecolor"]
        )
        axes = np.atleast_1d(axes)
        vid_im = None

    lines = []
    head_dots = []
    for i in range(n_panels):
        ax = axes[i]
        if i == 1 and success_padded is not None:
            _style_progress_ax(ax, theme, ylabel="Success")
            ax.set_xlim(-0.5, num_frames)
            line, = ax.plot([], [], lw=2.5, color=SUCCESS_COLOR, drawstyle="steps-post")
            lines.append(line)
            head_dots.append(None)
        else:
            _style_progress_ax(ax, theme, ylabel="Progress")
            ax.set_xlim(-0.5, num_frames)
            line, = ax.plot([], [], lw=2.5, color=PROGRESS_COLOR, drawstyle="steps-post")
            # Hollow marker tracking the newest progress point. Only
            # edgecolors/facecolors are passed: supplying color= as well makes
            # matplotlib warn about conflicting color specifications.
            head_dot = ax.scatter(
                [], [], s=36, zorder=5,
                edgecolors=PROGRESS_COLOR, facecolors="none",
            )
            lines.append(line)
            head_dots.append(head_dot)

    if title and str(title).strip():
        # Keep the title inside the top margin reserved by tight_layout below.
        fig.suptitle(
            str(title).strip(),
            fontsize=12,
            fontweight="bold",
            color=theme["text_color"],
            y=0.94,
        )

    def update(frame):
        """Draw curves up to `frame` (and the matching video frame); returns changed artists for blitting."""
        out = []
        if vid_im is not None and has_video:
            idx = min(int(frame), video_frames.shape[0] - 1)
            f = np.clip(video_frames[idx], 0, 255).astype(np.uint8)
            if f.ndim == 2:
                vid_im.set_cmap("gray")
            vid_im.set_array(f)
            out.append(vid_im)
        for i in range(n_panels):
            if i == 1 and success_padded is not None:
                x = np.arange(int(frame) + 1)
                y = success_padded[: int(frame) + 1]
                if len(x) > 0 and len(y) > 0:
                    lines[i].set_data(x, y)
            else:
                x = np.arange(int(frame) + 1)
                y = progress_pred[: int(frame) + 1]
                if len(x) > 0 and len(y) > 0:
                    lines[i].set_data(x, y)
                if head_dots[i] is not None:
                    head_dots[i].set_offsets([[frame, progress_pred[int(frame)]]])
            out.append(lines[i])
            if head_dots[i] is not None:
                out.append(head_dots[i])
        return out

    # Extra top space keeps the suptitle from being cut off; minimal
    # horizontal padding keeps the video panel tight.
    plt.tight_layout(rect=[0.01, 0, 0.99, 0.88], pad=0.3)
    ani = FuncAnimation(
        fig, update, frames=num_frames, interval=1000 / fps, blit=True
    )

    if not output_path:
        fd, output_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
    # Normalize any extension to .mp4.
    if output_path.endswith(".gif"):
        output_path = output_path[:-4] + ".mp4"
    if not output_path.lower().endswith(".mp4"):
        output_path = output_path + ".mp4"
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): bbox_inches="tight" inside an animation save can change the
    # rendered frame size between frames on some matplotlib versions — confirm
    # the ffmpeg writer tolerates it here.
    savefig_kwargs = {
        "facecolor": theme["facecolor"],
        "edgecolor": "none",
        "bbox_inches": "tight",
        "pad_inches": 0.12,
    }
    try:
        ani.save(
            output_path,
            writer="ffmpeg",
            fps=fps,
            dpi=120,
            savefig_kwargs=savefig_kwargs,
        )
    except Exception as e:
        # ffmpeg may be missing; degrade gracefully and signal failure with None.
        logger.warning(f"Could not save MP4 (ffmpeg?): {e}")
        output_path = None
    finally:
        plt.close(fig)

    return output_path