Spaces:
Running
Running
Anthony Liang committed on
Commit ·
f506da8
1
Parent(s): 6cf09b8
small ui updates
Browse files
app.py
CHANGED
|
@@ -227,7 +227,7 @@ def get_available_configs(dataset_name):
|
|
| 227 |
|
| 228 |
|
| 229 |
def get_trajectory_video_path(dataset, index, dataset_name):
|
| 230 |
-
"""Get video path from a trajectory in the dataset."""
|
| 231 |
try:
|
| 232 |
item = dataset[int(index)]
|
| 233 |
frames_data = item["frames"]
|
|
@@ -238,18 +238,25 @@ def get_trajectory_video_path(dataset, index, dataset_name):
|
|
| 238 |
video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
|
| 239 |
else:
|
| 240 |
video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
else:
|
| 243 |
-
return None, None
|
| 244 |
except Exception as e:
|
| 245 |
logger.error(f"Error getting trajectory video path: {e}")
|
| 246 |
-
return None, None
|
| 247 |
|
| 248 |
|
| 249 |
-
def extract_frames(video_path: str,
|
| 250 |
"""Extract frames from video file as numpy array (T, H, W, C).
|
| 251 |
-
|
| 252 |
Supports both local file paths and URLs (e.g., HuggingFace Hub URLs).
|
|
|
|
|
|
|
| 253 |
"""
|
| 254 |
if video_path is None:
|
| 255 |
return None
|
|
@@ -270,13 +277,31 @@ def extract_frames(video_path: str, max_frames: int = 16, fps: float = 1.0) -> n
|
|
| 270 |
vr = decord.VideoReader(video_path, num_threads=1)
|
| 271 |
total_frames = len(vr)
|
| 272 |
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
frame_indices = list(range(total_frames))
|
| 275 |
else:
|
| 276 |
-
frame_indices =
|
| 277 |
-
int(i * total_frames / max_frames)
|
| 278 |
-
for i in range(max_frames)
|
| 279 |
-
]
|
| 280 |
|
| 281 |
frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
|
| 282 |
del vr
|
|
@@ -303,7 +328,7 @@ def process_single_video(
|
|
| 303 |
return None, None, "Please provide a video."
|
| 304 |
|
| 305 |
try:
|
| 306 |
-
frames_array = extract_frames(video_path,
|
| 307 |
if frames_array is None or frames_array.size == 0:
|
| 308 |
return None, None, "Could not extract frames from video."
|
| 309 |
|
|
@@ -381,8 +406,8 @@ def process_dual_videos(
|
|
| 381 |
return "Please provide both videos.", None
|
| 382 |
|
| 383 |
try:
|
| 384 |
-
frames_array_a = extract_frames(video_a_path,
|
| 385 |
-
frames_array_b = extract_frames(video_b_path,
|
| 386 |
|
| 387 |
if frames_array_a is None or frames_array_a.size == 0:
|
| 388 |
return "Could not extract frames from video A.", None
|
|
@@ -483,7 +508,6 @@ def create_progress_plot(progress_pred: np.ndarray, num_frames: int) -> str:
|
|
| 483 |
ax.set_ylabel('Progress (0-1)', fontsize=18, fontweight='bold')
|
| 484 |
ax.set_title('Progress Prediction', fontsize=20, fontweight='bold')
|
| 485 |
ax.set_ylim([0, 1])
|
| 486 |
-
ax.legend(fontsize=14)
|
| 487 |
|
| 488 |
plt.tight_layout()
|
| 489 |
|
|
@@ -514,7 +538,6 @@ def create_success_plot(success_probs: np.ndarray, num_frames: int) -> str:
|
|
| 514 |
ax.set_ylabel('Success Probability (0-1)', fontsize=18, fontweight='bold')
|
| 515 |
ax.set_title('Success Prediction', fontsize=20, fontweight='bold')
|
| 516 |
ax.set_ylim([0, 1])
|
| 517 |
-
ax.legend(fontsize=14)
|
| 518 |
|
| 519 |
plt.tight_layout()
|
| 520 |
|
|
@@ -649,14 +672,18 @@ with demo:
|
|
| 649 |
load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
|
| 650 |
|
| 651 |
dataset_status_single = gr.Markdown("", visible=False)
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
|
| 661 |
|
| 662 |
gr.Markdown("---")
|
|
@@ -717,13 +744,104 @@ with demo:
|
|
| 717 |
def use_dataset_video(dataset, index, dataset_name):
|
| 718 |
"""Load video from dataset and update inputs."""
|
| 719 |
if dataset is None:
|
| 720 |
-
return None, "Complete the task", gr.update(value="No dataset loaded", visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 721 |
|
| 722 |
-
video_path, task = get_trajectory_video_path(dataset, index, dataset_name)
|
| 723 |
if video_path:
|
| 724 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 725 |
else:
|
| 726 |
-
return None, "Complete the task", gr.update(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
|
| 728 |
# Dataset selection handlers
|
| 729 |
dataset_name_single.change(
|
|
@@ -747,7 +865,27 @@ with demo:
|
|
| 747 |
use_dataset_video_btn.click(
|
| 748 |
fn=use_dataset_video,
|
| 749 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 750 |
-
outputs=[single_video_input, task_text_input, dataset_status_single]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 751 |
)
|
| 752 |
|
| 753 |
analyze_single_btn.click(
|
|
|
|
| 227 |
|
| 228 |
|
| 229 |
def get_trajectory_video_path(dataset, index, dataset_name):
|
| 230 |
+
"""Get video path and metadata from a trajectory in the dataset."""
|
| 231 |
try:
|
| 232 |
item = dataset[int(index)]
|
| 233 |
frames_data = item["frames"]
|
|
|
|
| 238 |
video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
|
| 239 |
else:
|
| 240 |
video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
|
| 241 |
+
|
| 242 |
+
task = item.get("task", "Complete the task")
|
| 243 |
+
quality_label = item.get("quality_label", None)
|
| 244 |
+
partial_success = item.get("partial_success", None)
|
| 245 |
+
|
| 246 |
+
return video_path, task, quality_label, partial_success
|
| 247 |
else:
|
| 248 |
+
return None, None, None, None
|
| 249 |
except Exception as e:
|
| 250 |
logger.error(f"Error getting trajectory video path: {e}")
|
| 251 |
+
return None, None, None, None
|
| 252 |
|
| 253 |
|
| 254 |
+
def extract_frames(video_path: str, fps: float = 1.0) -> np.ndarray:
|
| 255 |
"""Extract frames from video file as numpy array (T, H, W, C).
|
| 256 |
+
|
| 257 |
Supports both local file paths and URLs (e.g., HuggingFace Hub URLs).
|
| 258 |
+
Uses the provided ``fps`` to control how densely frames are sampled from
|
| 259 |
+
the underlying video; there is no additional hard cap on the number of frames.
|
| 260 |
"""
|
| 261 |
if video_path is None:
|
| 262 |
return None
|
|
|
|
| 277 |
vr = decord.VideoReader(video_path, num_threads=1)
|
| 278 |
total_frames = len(vr)
|
| 279 |
|
| 280 |
+
# Determine native FPS; fall back to a reasonable default if unavailable
|
| 281 |
+
try:
|
| 282 |
+
native_fps = float(vr.get_avg_fps())
|
| 283 |
+
except Exception:
|
| 284 |
+
native_fps = 1.0
|
| 285 |
+
|
| 286 |
+
# If user-specified fps is invalid or None, default to native fps
|
| 287 |
+
if fps is None or fps <= 0:
|
| 288 |
+
fps = native_fps
|
| 289 |
+
|
| 290 |
+
# Compute how many frames we want based on desired fps
|
| 291 |
+
# num_frames ≈ total_duration * fps = total_frames * (fps / native_fps)
|
| 292 |
+
if native_fps > 0:
|
| 293 |
+
desired_frames = int(round(total_frames * (fps / native_fps)))
|
| 294 |
+
else:
|
| 295 |
+
desired_frames = total_frames
|
| 296 |
+
|
| 297 |
+
# Clamp to [1, total_frames]
|
| 298 |
+
desired_frames = max(1, min(desired_frames, total_frames))
|
| 299 |
+
|
| 300 |
+
# Evenly sample indices to match the desired number of frames
|
| 301 |
+
if desired_frames == total_frames:
|
| 302 |
frame_indices = list(range(total_frames))
|
| 303 |
else:
|
| 304 |
+
frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
|
| 307 |
del vr
|
|
|
|
| 328 |
return None, None, "Please provide a video."
|
| 329 |
|
| 330 |
try:
|
| 331 |
+
frames_array = extract_frames(video_path, fps=fps)
|
| 332 |
if frames_array is None or frames_array.size == 0:
|
| 333 |
return None, None, "Could not extract frames from video."
|
| 334 |
|
|
|
|
| 406 |
return "Please provide both videos.", None
|
| 407 |
|
| 408 |
try:
|
| 409 |
+
frames_array_a = extract_frames(video_a_path, fps=fps)
|
| 410 |
+
frames_array_b = extract_frames(video_b_path, fps=fps)
|
| 411 |
|
| 412 |
if frames_array_a is None or frames_array_a.size == 0:
|
| 413 |
return "Could not extract frames from video A.", None
|
|
|
|
| 508 |
ax.set_ylabel('Progress (0-1)', fontsize=18, fontweight='bold')
|
| 509 |
ax.set_title('Progress Prediction', fontsize=20, fontweight='bold')
|
| 510 |
ax.set_ylim([0, 1])
|
|
|
|
| 511 |
|
| 512 |
plt.tight_layout()
|
| 513 |
|
|
|
|
| 538 |
ax.set_ylabel('Success Probability (0-1)', fontsize=18, fontweight='bold')
|
| 539 |
ax.set_title('Success Prediction', fontsize=20, fontweight='bold')
|
| 540 |
ax.set_ylim([0, 1])
|
|
|
|
| 541 |
|
| 542 |
plt.tight_layout()
|
| 543 |
|
|
|
|
| 672 |
load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
|
| 673 |
|
| 674 |
dataset_status_single = gr.Markdown("", visible=False)
|
| 675 |
+
with gr.Row():
|
| 676 |
+
prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm")
|
| 677 |
+
trajectory_slider = gr.Slider(
|
| 678 |
+
minimum=0,
|
| 679 |
+
maximum=0,
|
| 680 |
+
step=1,
|
| 681 |
+
value=0,
|
| 682 |
+
label="Trajectory Index",
|
| 683 |
+
interactive=True
|
| 684 |
+
)
|
| 685 |
+
next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm")
|
| 686 |
+
trajectory_metadata = gr.Markdown("", visible=False)
|
| 687 |
use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
|
| 688 |
|
| 689 |
gr.Markdown("---")
|
|
|
|
| 744 |
def _format_trajectory_metadata(quality_label, partial_success):
    """Render a trajectory's optional metadata as Markdown text.

    Args:
        quality_label: Quality label of the trajectory; skipped when falsy.
        partial_success: Partial-success score; skipped when ``None``.

    Returns:
        Markdown text with one paragraph per available field, or ``""`` when
        neither field is available.
    """
    entries = []
    if quality_label:
        entries.append(f"**Quality Label:** {quality_label}")
    if partial_success is not None:
        try:
            entries.append(f"**Partial Success:** {float(partial_success):.3f}")
        except (TypeError, ValueError):
            # A non-numeric value must not crash the UI handler; show it raw.
            entries.append(f"**Partial Success:** {partial_success}")
    # Blank line between entries: Markdown treats a single newline as a soft
    # break, so entries joined with "\n" would be rendered on the same line.
    return "\n\n".join(entries)


def _step_trajectory(dataset, current_idx, dataset_name, step):
    """Move ``step`` trajectories away from ``current_idx`` and load the result.

    Shared implementation of ``next_trajectory`` and ``prev_trajectory``.

    Returns:
        A 5-tuple ``(slider value, video path, task text, metadata update,
        status update)``, the order both navigation buttons are wired to.
    """
    if dataset is None:
        return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)

    last_idx = len(dataset) - 1
    # Clamp on both sides so an empty dataset cannot produce index -1, and cast
    # to int so a float slider value does not leak into the status text.
    target_idx = max(0, min(int(current_idx or 0) + step, last_idx))
    video_path, task, quality_label, partial_success = get_trajectory_video_path(
        dataset, target_idx, dataset_name
    )

    if not video_path:
        # Keep the slider where it was and tell the user why nothing loaded,
        # using the same message as use_dataset_video.
        return (
            current_idx,
            None,
            "Complete the task",
            gr.update(visible=False),
            gr.update(value="❌ Error loading trajectory", visible=True),
        )

    metadata_text = _format_trajectory_metadata(quality_label, partial_success)
    return (
        target_idx,
        video_path,
        # NOTE(review): presumably a row can carry a null "task"; fall back to
        # the default prompt instead of pushing None into the textbox.
        task or "Complete the task",
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"✅ Trajectory {target_idx}/{last_idx}", visible=True),
    )


def use_dataset_video(dataset, index, dataset_name):
    """Load the selected dataset trajectory into the single-video inputs.

    Returns:
        A 4-tuple ``(video path, task text, status update, metadata update)``.
        When no dataset is loaded or the trajectory cannot be resolved, the
        video is ``None``, the task is the default prompt, the status shows the
        reason and the metadata panel is hidden.
    """
    if dataset is None:
        return None, "Complete the task", gr.update(value="No dataset loaded", visible=True), gr.update(visible=False)

    video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
    if not video_path:
        return None, "Complete the task", gr.update(value="❌ Error loading trajectory", visible=True), gr.update(visible=False)

    metadata_text = _format_trajectory_metadata(quality_label, partial_success)
    status_text = f"✅ Loaded trajectory {index} from dataset"
    if metadata_text:
        status_text += f"\n\n{metadata_text}"

    return (
        video_path,
        # NOTE(review): presumably a row can carry a null "task"; fall back to
        # the default prompt instead of pushing None into the textbox.
        task or "Complete the task",
        gr.update(value=status_text, visible=True),
        gr.update(value=metadata_text, visible=bool(metadata_text)),
    )


def next_trajectory(dataset, current_idx, dataset_name):
    """Go to the next trajectory, stopping at the last one."""
    return _step_trajectory(dataset, current_idx, dataset_name, 1)


def prev_trajectory(dataset, current_idx, dataset_name):
    """Go to the previous trajectory, stopping at index 0."""
    return _step_trajectory(dataset, current_idx, dataset_name, -1)


def update_trajectory_on_slider_change(dataset, index, dataset_name):
    """Refresh the metadata panel and status line when the slider moves.

    Returns:
        A 2-tuple ``(metadata update, status update)``; both components are
        hidden when no dataset is loaded or the trajectory cannot be resolved.
    """
    if dataset is None:
        return gr.update(visible=False), gr.update(visible=False)

    # Only the video path (as a success flag) and the metadata are needed here.
    video_path, _task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
    if not video_path:
        return gr.update(visible=False), gr.update(visible=False)

    metadata_text = _format_trajectory_metadata(quality_label, partial_success)
    return (
        gr.update(value=metadata_text, visible=bool(metadata_text)),
        gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True),
    )
|
| 845 |
|
| 846 |
# Dataset selection handlers
|
| 847 |
dataset_name_single.change(
|
|
|
|
| 865 |
use_dataset_video_btn.click(
|
| 866 |
fn=use_dataset_video,
|
| 867 |
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 868 |
+
outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata]
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
# Navigation buttons
|
| 872 |
+
next_traj_btn.click(
|
| 873 |
+
fn=next_trajectory,
|
| 874 |
+
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 875 |
+
outputs=[trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single]
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
prev_traj_btn.click(
|
| 879 |
+
fn=prev_trajectory,
|
| 880 |
+
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 881 |
+
outputs=[trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single]
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
# Update metadata when slider changes
|
| 885 |
+
trajectory_slider.change(
|
| 886 |
+
fn=update_trajectory_on_slider_change,
|
| 887 |
+
inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
|
| 888 |
+
outputs=[trajectory_metadata, dataset_status_single]
|
| 889 |
)
|
| 890 |
|
| 891 |
analyze_single_btn.click(
|