Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Utility functions for visualization in Robometer (RBM) evaluations. | |
| """ | |
| from typing import Optional | |
| import os | |
| import logging | |
| import tempfile | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as ticker | |
| import decord | |
logger = logging.getLogger(__name__)

# Curve colours for the progress/success animation (Robometer red for both).
PROGRESS_COLOR = "#B20000"
SUCCESS_COLOR = "#B20000"

# Figure themes: background colour, text colour and axis-spine colour.
THEME_LIGHT = dict(facecolor="white", text_color="black", spine_color="#333333")
THEME_DARK = dict(facecolor="black", text_color="white", spine_color="#444444")

# Serif typography (Palatino first, with fallbacks) applied globally at import time.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Palatino", "Palatino Linotype", "DejaVu Serif", "serif"],
        "font.size": 11,
    }
)
def wrap_title(text: str, max_chars_per_line: int = 48) -> str:
    """Wrap a long title onto at most two lines, breaking at word boundaries.

    The first line is filled greedily with whole words up to
    ``max_chars_per_line`` characters; every remaining word goes onto the
    second line (which is therefore not length-limited).

    Args:
        text: Title to wrap. Falsy or whitespace-only values are returned unchanged.
        max_chars_per_line: Maximum length of the first line.

    Returns:
        The stripped title, with a single ``"\\n"`` inserted when it had to be split.
        A title that cannot be split (e.g. one over-long word) is returned on one line.
    """
    if not text or not str(text).strip():
        return text
    text = str(text).strip()
    if len(text) <= max_chars_per_line:
        return text

    words = text.split()
    # The first word always opens line 1, even when it alone exceeds the limit;
    # otherwise line 1 would be empty and the result would start with "\n".
    first_line = [words[0]]
    first_len = len(words[0])
    split_at = 1
    for word in words[1:]:
        needed = len(word) + 1  # +1 for the separating space
        if first_len + needed > max_chars_per_line:
            break
        first_line.append(word)
        first_len += needed
        split_at += 1

    second_line = words[split_at:]
    if not second_line:
        return text
    return " ".join(first_line) + "\n" + " ".join(second_line)
def create_combined_progress_success_plot(
    progress_pred: np.ndarray,
    num_frames: int,
    success_binary: Optional[np.ndarray] = None,
    success_probs: Optional[np.ndarray] = None,
    success_labels: Optional[np.ndarray] = None,
    is_discrete_mode: bool = False,
    title: Optional[str] = None,
    loss: Optional[float] = None,
    pearson: Optional[float] = None,
) -> plt.Figure:
    """Build a static figure of predicted progress and, optionally, success signals.

    The layout is a single progress panel, or three side-by-side panels
    (progress, binary success, success probability) when ``success_binary``
    is given and has the same length as ``progress_pred``.

    Args:
        progress_pred: Progress predictions, plotted on a fixed 0-1 y-axis.
        num_frames: Number of frames (accepted for interface compatibility; not used here).
        success_binary: Optional binary success predictions; enables the 3-panel layout.
        success_probs: Optional success probabilities for the third panel. If omitted
            while the 3-panel layout is active, the third panel is left empty.
        success_labels: Optional ground-truth success labels, overlaid in green on a
            success panel when their length matches that panel's data.
        is_discrete_mode: Deprecated and unused; kept for compatibility.
        title: Figure title. When None, one is built from "Progress" plus loss/pearson.
        loss: Optional loss value shown in the auto-generated title.
        pearson: Optional Pearson correlation shown in the auto-generated title.

    Returns:
        The matplotlib Figure.
    """
    show_success = success_binary is not None and len(success_binary) == len(progress_pred)

    if show_success:
        fig, (ax, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 3.5))
    else:
        fig, ax = plt.subplots(figsize=(7, 3.5))
        ax2 = ax3 = None

    # --- Progress panel ---
    ax.plot(progress_pred, linewidth=2)
    ax.set_ylabel("Progress")

    # Auto-generate the title when none was supplied; long titles wrap onto two lines.
    if title is None:
        pieces = ["Progress"]
        if loss is not None:
            pieces.append(f"Loss: {loss:.3f}")
        if pearson is not None:
            pieces.append(f"Pearson: {pearson:.2f}")
        title = ", ".join(pieces)
    fig.suptitle(wrap_title(title))

    ax.set_ylim(0, 1)
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
    ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

    # Shared look of the ground-truth overlay on the success panels.
    gt_kwargs = {"where": "post", "linewidth": 2, "label": "Ground Truth", "color": "green"}

    # --- Binary success panel ---
    if ax2 is not None:
        ax2.step(
            range(len(success_binary)),
            success_binary,
            where="post",
            linewidth=2,
            label="Predicted",
            color="blue",
        )
        if success_labels is not None and len(success_labels) == len(success_binary):
            ax2.step(range(len(success_labels)), success_labels, **gt_kwargs)
        ax2.set_ylabel("Success (Binary)")
        ax2.set_ylim(-0.05, 1.05)
        for side in ("right", "top"):
            ax2.spines[side].set_visible(False)
        ax2.set_yticks([0, 1])
        ax2.legend()

    # --- Success probability panel ---
    if ax3 is not None and success_probs is not None:
        ax3.plot(
            range(len(success_probs)),
            success_probs,
            linewidth=2,
            label="Success Prob",
            color="purple",
        )
        if success_labels is not None and len(success_labels) == len(success_probs):
            # Dashed here so the labels stay distinguishable from the probability curve.
            ax3.step(range(len(success_labels)), success_labels, linestyle="--", **gt_kwargs)
        ax3.set_ylabel("Success Probability")
        ax3.set_ylim(-0.05, 1.05)
        for side in ("right", "top"):
            ax3.spines[side].set_visible(False)
        ax3.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
        ax3.legend()

    plt.tight_layout()
    return fig
def extract_frames(video_path: str, fps: float = 1.0, max_frames: int = 64) -> Optional[np.ndarray]:
    """Extract frames from a video file as a numpy array of shape (T, H, W, C).

    Accepts local file paths and http(s) URLs. ``fps`` controls how densely
    frames are sampled relative to the video's native frame rate, and the
    result is capped at ``max_frames`` frames to bound memory use.

    Args:
        video_path: Path or URL of the video. A tuple/list is also accepted, in
            which case its first element is used. ``str``, ``bytes`` and
            ``os.PathLike`` values are supported.
        fps: Frames per second to extract (default 1.0). ``None`` or a
            non-positive value means "use the native frame rate".
        max_frames: Upper bound on the number of returned frames (default 64).

    Returns:
        Array of sampled frames, or None if the path is missing/invalid or the
        video could not be decoded (the reason is logged).
    """
    # Unwrap (path, extra) containers first so a missing first element is handled
    # by the same None check as a missing argument.
    if isinstance(video_path, (tuple, list)):
        video_path = video_path[0] if video_path else None
    if video_path is None:
        return None

    # Normalise str / bytes / PathLike to str so the string checks below are safe.
    try:
        video_path = os.fsdecode(video_path)
    except TypeError:
        logger.warning(f"Unsupported video path type: {type(video_path).__name__}")
        return None

    # URLs are passed straight to the reader; local paths must exist.
    is_url = video_path.startswith(("http://", "https://"))
    if not is_url and not os.path.exists(video_path):
        logger.warning(f"Video path does not exist: {video_path}")
        return None

    try:
        # decord.VideoReader can handle both local files and URLs
        vr = decord.VideoReader(video_path, num_threads=1)
        total_frames = len(vr)

        # Determine native FPS; fall back to a reasonable default if unavailable
        try:
            native_fps = float(vr.get_avg_fps())
        except Exception:
            native_fps = 1.0

        # If user-specified fps is invalid or None, default to native fps
        if fps is None or fps <= 0:
            fps = native_fps

        # num_frames ≈ total_duration * fps = total_frames * (fps / native_fps)
        if native_fps > 0:
            desired_frames = int(round(total_frames * (fps / native_fps)))
        else:
            desired_frames = total_frames

        # Clamp to [1, total_frames]
        desired_frames = max(1, min(desired_frames, total_frames))

        # Cap at max_frames: critical when fps is high or the video is long.
        if desired_frames > max_frames:
            logger.warning(
                f"Requested {desired_frames} frames but capping at {max_frames} "
                f"to prevent memory issues (video has {total_frames} frames at {native_fps:.2f} fps, "
                f"requested extraction at {fps:.2f} fps)"
            )
            desired_frames = max_frames

        # Evenly sample indices to match the desired number of frames
        if desired_frames == total_frames:
            frame_indices = list(range(total_frames))
        else:
            frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()

        frames_array = vr.get_batch(frame_indices).asnumpy()  # Shape: (T, H, W, C)
        del vr
        return frames_array
    except Exception as e:
        # Best-effort API: decoding problems are reported and surfaced as None.
        logger.error(f"Error extracting frames from {video_path}: {e}")
        return None
def resize_frames_keep_aspect(
    frames: np.ndarray,
    max_edge: int = 480,
) -> np.ndarray:
    """Downscale video frames so the longer edge is at most ``max_edge``, keeping aspect ratio.

    Frames are never upscaled. Resizing uses bilinear interpolation from
    ``scipy.ndimage.zoom``; if scipy is not installed the input is returned
    unchanged.

    Args:
        frames: Array of shape (T, H, W, C). Anything else (None, empty, or a
            different rank) is returned as-is.
        max_edge: Maximum allowed size of the longer spatial edge, in pixels.

    Returns:
        The original ``frames`` object when no resize is needed or possible,
        otherwise a new uint8 array of shape (T, new_H, new_W, C) with values
        clipped to [0, 255].
    """
    if frames is None or frames.size == 0 or frames.ndim != 4:
        return frames
    _, h, w, _ = frames.shape
    if h <= 0 or w <= 0:
        return frames

    scale = min(max_edge / max(h, w), 1.0)
    if scale >= 1.0:
        return frames
    new_h = max(1, round(h * scale))
    new_w = max(1, round(w * scale))

    try:
        from scipy.ndimage import zoom
    except ImportError:
        return frames

    # Resize one frame at a time: only a single frame is ever held as float64,
    # instead of a float64 copy of the whole clip (8x the uint8 input size).
    spatial_zoom = (new_h / h, new_w / w, 1.0)
    resized = [
        np.clip(zoom(frame.astype(np.float64), spatial_zoom, order=1), 0, 255).astype(np.uint8)
        for frame in frames
    ]
    return np.stack(resized, axis=0)
def _style_progress_ax(ax, theme: dict, ylabel: str = "Progress"):
    """Apply the shared look used by the animated progress and success panels.

    Args:
        ax: Matplotlib axes to style (modified in place).
        theme: Mapping with "facecolor", "text_color" and "spine_color" entries.
        ylabel: Label for the y-axis.
    """
    text_color = theme["text_color"]
    spine_color = theme["spine_color"]

    ax.set_facecolor(theme["facecolor"])

    # Fixed value range with a small margin so lines at 0 and 1 are not clipped.
    ax.set_ylim(-0.05, 1.05)
    ax.set_yticks([0, 0.5, 1.0])
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True, nbins=8))

    ax.set_xlabel("")
    ax.set_ylabel(ylabel, fontsize=12, fontweight="bold", color=text_color)
    ax.tick_params(axis="both", labelsize=10, colors=text_color)

    # Keep only the left/bottom spines, tinted with the theme colour.
    for side in ("left", "bottom"):
        ax.spines[side].set_color(spine_color)
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
def create_progress_success_gif(
    progress_pred: np.ndarray,
    success_data: Optional[np.ndarray] = None,
    video_frames: Optional[np.ndarray] = None,
    output_path: Optional[str] = None,
    title: Optional[str] = None,
    duration_sec: float = 5.0,
    theme: Optional[dict] = None,
) -> Optional[str]:
    """Render an animated MP4 of progress (and optionally success) curves growing frame by frame.

    Despite the name, the output is always an ``.mp4`` written with the
    ffmpeg writer. When ``video_frames`` are supplied they are shown in a
    panel to the left of the curves.

    Args:
        progress_pred: Per-frame progress values; its length defines the number
            of animation frames.
        success_data: Optional per-frame success values. If shorter than
            ``progress_pred`` it is padded by repeating its last value.
        video_frames: Optional frames, indexed along the first axis. Longer
            clips are truncated to the number of predictions; shorter clips are
            padded by repeating the last frame.
        output_path: Destination path. A ``.gif`` suffix is replaced by ``.mp4``
            and ``.mp4`` is appended if missing. When omitted, a temporary file
            is created.
        title: Optional figure title (wrapped onto two lines when long).
        duration_sec: Target playback length in seconds; fps is derived as
            ``num_frames / duration_sec`` (at least 1). ``None`` or non-positive
            values fall back to 5 seconds.
        theme: Theme mapping ("facecolor", "text_color", "spine_color");
            defaults to the light theme.

    Returns:
        Path of the written MP4, or None when there is nothing to plot or the
        file could not be saved.
    """
    from matplotlib.animation import FuncAnimation
    from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

    theme = theme or THEME_LIGHT
    progress_pred = np.atleast_1d(progress_pred).astype(float)
    num_frames = len(progress_pred)
    if num_frames == 0:
        return None

    # FPS so the full animation runs for duration_sec; guard against a zero or
    # negative duration, which would otherwise divide by zero.
    if duration_sec is None or duration_sec <= 0:
        duration_sec = 5.0
    fps = max(1, round(num_frames / duration_sec))

    success_padded = None
    if success_data is not None and np.size(success_data) > 0:
        s = np.atleast_1d(success_data).astype(float)
        if len(s) < num_frames:
            s = np.pad(s, (0, num_frames - len(s)), mode="edge")
        success_padded = s

    # Any non-empty stack of frames qualifies; its length is then matched to the
    # number of predictions by truncating or by repeating the last frame.
    video_shape = getattr(video_frames, "shape", ())
    has_video = video_frames is not None and len(video_shape) >= 3 and video_shape[0] > 0
    if has_video:
        if video_shape[0] > num_frames:
            video_frames = video_frames[:num_frames]
        elif video_shape[0] < num_frames:
            pad = np.repeat(video_frames[-1:], num_frames - video_shape[0], axis=0)
            video_frames = np.concatenate([video_frames, pad], axis=0)
        video_frames = resize_frames_keep_aspect(video_frames, max_edge=480)

    n_panels = 2 if success_padded is not None else 1
    width_per_panel = 5.5
    figsize = (width_per_panel * n_panels, 3.2) if not has_video else (2 + width_per_panel * n_panels, 3.2)

    fig = plt.figure(facecolor=theme["facecolor"], figsize=figsize)
    created_temp_file = False
    try:
        vid_im = None
        if has_video:
            # Give plots more room: smaller video column, more wspace so video doesn't cover Progress
            gs = GridSpec(1, 2, figure=fig, width_ratios=[0.85, n_panels], wspace=0.4)
            ax_video = fig.add_subplot(gs[0])
            ax_video.set_facecolor(theme["facecolor"])
            ax_video.axis("off")
            first = video_frames[0]
            # Preserve aspect ratio so the video is not flattened
            vid_im = ax_video.imshow(
                np.clip(first, 0, 255).astype(np.uint8) if first.ndim >= 3 else first,
                cmap="gray" if first.ndim == 2 else None,
                aspect="equal",
            )
            gs_right = GridSpecFromSubplotSpec(1, n_panels, subplot_spec=gs[1], wspace=0.3)
            axes = [fig.add_subplot(gs_right[0, j]) for j in range(n_panels)]
        else:
            axes = np.atleast_1d(fig.subplots(1, n_panels))

        # Panel 0 is always progress (with a hollow "head" marker on the newest
        # point); panel 1, when present, is success (no marker).
        panels = [("Progress", PROGRESS_COLOR, progress_pred, True)]
        if success_padded is not None:
            panels.append(("Success", SUCCESS_COLOR, success_padded, False))

        lines = []
        head_dots = []
        for ax, (ylabel, color, _values, with_head) in zip(axes, panels):
            _style_progress_ax(ax, theme, ylabel=ylabel)
            ax.set_xlim(-0.5, num_frames)
            line, = ax.plot([], [], lw=2.5, color=color, drawstyle="steps-post")
            lines.append(line)
            if with_head:
                head_dots.append(
                    ax.scatter(
                        [], [], color=PROGRESS_COLOR, s=36, zorder=5,
                        edgecolors=PROGRESS_COLOR, facecolors="none",
                    )
                )
            else:
                head_dots.append(None)

        if title and str(title).strip():
            # Place title inside figure top margin; wrap long task text onto two lines
            fig.suptitle(
                wrap_title(str(title).strip()),
                fontsize=12,
                fontweight="bold",
                color=theme["text_color"],
                y=0.94,
            )

        def update(frame):
            """Reveal data up to and including ``frame``; return the changed artists."""
            idx = int(frame)
            artists = []
            if vid_im is not None:
                f = np.clip(video_frames[min(idx, video_frames.shape[0] - 1)], 0, 255).astype(np.uint8)
                if f.ndim == 2:
                    vid_im.set_cmap("gray")
                vid_im.set_array(f)
                artists.append(vid_im)
            xs = np.arange(idx + 1)
            for line, dot, (_ylabel, _color, values, _with_head) in zip(lines, head_dots, panels):
                line.set_data(xs, values[: idx + 1])
                artists.append(line)
                if dot is not None:
                    dot.set_offsets([[frame, progress_pred[idx]]])
                    artists.append(dot)
            return artists

        # Leave extra top space so suptitle (task text) is not cut off; minimal horizontal pad for tight video
        fig.tight_layout(rect=[0.01, 0, 0.99, 0.88], pad=0.3)
        ani = FuncAnimation(
            fig, update, frames=num_frames, interval=1000 / fps, blit=True
        )

        if not output_path:
            fd, output_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            created_temp_file = True
        # Normalize to .mp4
        if output_path.endswith(".gif"):
            output_path = output_path[:-4] + ".mp4"
        if not output_path.lower().endswith(".mp4"):
            output_path = output_path + ".mp4"
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # No bbox_inches="tight" here: Animation.save discards it (with a warning)
        # because every frame must have the same pixel size.
        savefig_kwargs = {
            "facecolor": theme["facecolor"],
            "edgecolor": "none",
        }
        try:
            ani.save(
                output_path,
                writer="ffmpeg",
                fps=fps,
                dpi=120,
                savefig_kwargs=savefig_kwargs,
            )
        except Exception as e:
            logger.warning(f"Could not save MP4 (ffmpeg?): {e}")
            # Do not leave behind the placeholder file we created ourselves.
            if created_temp_file:
                try:
                    os.remove(output_path)
                except OSError:
                    pass
            return None
        return output_path
    finally:
        # Always release the figure, including when layout/animation setup raises.
        plt.close(fig)