Anthony Liang commited on
Commit
f506da8
·
1 Parent(s): 6cf09b8

small ui updates

Browse files
Files changed (1) hide show
  1. app.py +167 -29
app.py CHANGED
@@ -227,7 +227,7 @@ def get_available_configs(dataset_name):
227
 
228
 
229
  def get_trajectory_video_path(dataset, index, dataset_name):
230
- """Get video path from a trajectory in the dataset."""
231
  try:
232
  item = dataset[int(index)]
233
  frames_data = item["frames"]
@@ -238,18 +238,25 @@ def get_trajectory_video_path(dataset, index, dataset_name):
238
  video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
239
  else:
240
  video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
241
- return video_path, item.get("task", "Complete the task")
 
 
 
 
 
242
  else:
243
- return None, None
244
  except Exception as e:
245
  logger.error(f"Error getting trajectory video path: {e}")
246
- return None, None
247
 
248
 
249
- def extract_frames(video_path: str, max_frames: int = 16, fps: float = 1.0) -> np.ndarray:
250
  """Extract frames from video file as numpy array (T, H, W, C).
251
-
252
  Supports both local file paths and URLs (e.g., HuggingFace Hub URLs).
 
 
253
  """
254
  if video_path is None:
255
  return None
@@ -270,13 +277,31 @@ def extract_frames(video_path: str, max_frames: int = 16, fps: float = 1.0) -> n
270
  vr = decord.VideoReader(video_path, num_threads=1)
271
  total_frames = len(vr)
272
 
273
- if total_frames <= max_frames:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  frame_indices = list(range(total_frames))
275
  else:
276
- frame_indices = [
277
- int(i * total_frames / max_frames)
278
- for i in range(max_frames)
279
- ]
280
 
281
  frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
282
  del vr
@@ -303,7 +328,7 @@ def process_single_video(
303
  return None, None, "Please provide a video."
304
 
305
  try:
306
- frames_array = extract_frames(video_path, max_frames=16, fps=fps)
307
  if frames_array is None or frames_array.size == 0:
308
  return None, None, "Could not extract frames from video."
309
 
@@ -381,8 +406,8 @@ def process_dual_videos(
381
  return "Please provide both videos.", None
382
 
383
  try:
384
- frames_array_a = extract_frames(video_a_path, max_frames=16, fps=fps)
385
- frames_array_b = extract_frames(video_b_path, max_frames=16, fps=fps)
386
 
387
  if frames_array_a is None or frames_array_a.size == 0:
388
  return "Could not extract frames from video A.", None
@@ -483,7 +508,6 @@ def create_progress_plot(progress_pred: np.ndarray, num_frames: int) -> str:
483
  ax.set_ylabel('Progress (0-1)', fontsize=18, fontweight='bold')
484
  ax.set_title('Progress Prediction', fontsize=20, fontweight='bold')
485
  ax.set_ylim([0, 1])
486
- ax.legend(fontsize=14)
487
 
488
  plt.tight_layout()
489
 
@@ -514,7 +538,6 @@ def create_success_plot(success_probs: np.ndarray, num_frames: int) -> str:
514
  ax.set_ylabel('Success Probability (0-1)', fontsize=18, fontweight='bold')
515
  ax.set_title('Success Prediction', fontsize=20, fontweight='bold')
516
  ax.set_ylim([0, 1])
517
- ax.legend(fontsize=14)
518
 
519
  plt.tight_layout()
520
 
@@ -649,14 +672,18 @@ with demo:
649
  load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
650
 
651
  dataset_status_single = gr.Markdown("", visible=False)
652
- trajectory_slider = gr.Slider(
653
- minimum=0,
654
- maximum=0,
655
- step=1,
656
- value=0,
657
- label="Trajectory Index",
658
- interactive=False
659
- )
 
 
 
 
660
  use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
661
 
662
  gr.Markdown("---")
@@ -717,13 +744,104 @@ with demo:
717
  def use_dataset_video(dataset, index, dataset_name):
718
  """Load video from dataset and update inputs."""
719
  if dataset is None:
720
- return None, "Complete the task", gr.update(value="No dataset loaded", visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
 
722
- video_path, task = get_trajectory_video_path(dataset, index, dataset_name)
723
  if video_path:
724
- return video_path, task, gr.update(value=f"✅ Loaded trajectory {index} from dataset", visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
725
  else:
726
- return None, "Complete the task", gr.update(value="❌ Error loading trajectory", visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
 
728
  # Dataset selection handlers
729
  dataset_name_single.change(
@@ -747,7 +865,27 @@ with demo:
747
  use_dataset_video_btn.click(
748
  fn=use_dataset_video,
749
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
750
- outputs=[single_video_input, task_text_input, dataset_status_single]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
751
  )
752
 
753
  analyze_single_btn.click(
 
227
 
228
 
229
  def get_trajectory_video_path(dataset, index, dataset_name):
230
+ """Get video path and metadata from a trajectory in the dataset."""
231
  try:
232
  item = dataset[int(index)]
233
  frames_data = item["frames"]
 
238
  video_path = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/{frames_data}"
239
  else:
240
  video_path = f"https://huggingface.co/datasets/aliangdw/rfm/resolve/main/{frames_data}"
241
+
242
+ task = item.get("task", "Complete the task")
243
+ quality_label = item.get("quality_label", None)
244
+ partial_success = item.get("partial_success", None)
245
+
246
+ return video_path, task, quality_label, partial_success
247
  else:
248
+ return None, None, None, None
249
  except Exception as e:
250
  logger.error(f"Error getting trajectory video path: {e}")
251
+ return None, None, None, None
252
 
253
 
254
+ def extract_frames(video_path: str, fps: float = 1.0) -> np.ndarray:
255
  """Extract frames from video file as numpy array (T, H, W, C).
256
+
257
  Supports both local file paths and URLs (e.g., HuggingFace Hub URLs).
258
+ Uses the provided ``fps`` to control how densely frames are sampled from
259
+ the underlying video; there is no additional hard cap on the number of frames.
260
  """
261
  if video_path is None:
262
  return None
 
277
  vr = decord.VideoReader(video_path, num_threads=1)
278
  total_frames = len(vr)
279
 
280
+ # Determine native FPS; fall back to a reasonable default if unavailable
281
+ try:
282
+ native_fps = float(vr.get_avg_fps())
283
+ except Exception:
284
+ native_fps = 1.0
285
+
286
+ # If user-specified fps is invalid or None, default to native fps
287
+ if fps is None or fps <= 0:
288
+ fps = native_fps
289
+
290
+ # Compute how many frames we want based on desired fps
291
+ # num_frames ≈ total_duration * fps = total_frames * (fps / native_fps)
292
+ if native_fps > 0:
293
+ desired_frames = int(round(total_frames * (fps / native_fps)))
294
+ else:
295
+ desired_frames = total_frames
296
+
297
+ # Clamp to [1, total_frames]
298
+ desired_frames = max(1, min(desired_frames, total_frames))
299
+
300
+ # Evenly sample indices to match the desired number of frames
301
+ if desired_frames == total_frames:
302
  frame_indices = list(range(total_frames))
303
  else:
304
+ frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
 
 
 
305
 
306
  frames_array = vr.get_batch(frame_indices).asnumpy() # Shape: (T, H, W, C)
307
  del vr
 
328
  return None, None, "Please provide a video."
329
 
330
  try:
331
+ frames_array = extract_frames(video_path, fps=fps)
332
  if frames_array is None or frames_array.size == 0:
333
  return None, None, "Could not extract frames from video."
334
 
 
406
  return "Please provide both videos.", None
407
 
408
  try:
409
+ frames_array_a = extract_frames(video_a_path, fps=fps)
410
+ frames_array_b = extract_frames(video_b_path, fps=fps)
411
 
412
  if frames_array_a is None or frames_array_a.size == 0:
413
  return "Could not extract frames from video A.", None
 
508
  ax.set_ylabel('Progress (0-1)', fontsize=18, fontweight='bold')
509
  ax.set_title('Progress Prediction', fontsize=20, fontweight='bold')
510
  ax.set_ylim([0, 1])
 
511
 
512
  plt.tight_layout()
513
 
 
538
  ax.set_ylabel('Success Probability (0-1)', fontsize=18, fontweight='bold')
539
  ax.set_title('Success Prediction', fontsize=20, fontweight='bold')
540
  ax.set_ylim([0, 1])
 
541
 
542
  plt.tight_layout()
543
 
 
672
  load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
673
 
674
  dataset_status_single = gr.Markdown("", visible=False)
675
+ with gr.Row():
676
+ prev_traj_btn = gr.Button("⬅️ Prev", variant="secondary", size="sm")
677
+ trajectory_slider = gr.Slider(
678
+ minimum=0,
679
+ maximum=0,
680
+ step=1,
681
+ value=0,
682
+ label="Trajectory Index",
683
+ interactive=True
684
+ )
685
+ next_traj_btn = gr.Button("Next ➡️", variant="secondary", size="sm")
686
+ trajectory_metadata = gr.Markdown("", visible=False)
687
  use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
688
 
689
  gr.Markdown("---")
 
744
  def use_dataset_video(dataset, index, dataset_name):
745
  """Load video from dataset and update inputs."""
746
  if dataset is None:
747
+ return None, "Complete the task", gr.update(value="No dataset loaded", visible=True), gr.update(visible=False)
748
+
749
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
750
+ if video_path:
751
+ # Build metadata text
752
+ metadata_lines = []
753
+ if quality_label:
754
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
755
+ if partial_success is not None:
756
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
757
+
758
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
759
+ status_text = f"✅ Loaded trajectory {index} from dataset"
760
+ if metadata_text:
761
+ status_text += f"\n\n{metadata_text}"
762
+
763
+ return (
764
+ video_path,
765
+ task,
766
+ gr.update(value=status_text, visible=True),
767
+ gr.update(value=metadata_text, visible=bool(metadata_text))
768
+ )
769
+ else:
770
+ return None, "Complete the task", gr.update(value="❌ Error loading trajectory", visible=True), gr.update(visible=False)
771
+
772
+ def next_trajectory(dataset, current_idx, dataset_name):
773
+ """Go to next trajectory."""
774
+ if dataset is None:
775
+ return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
776
+ next_idx = min(current_idx + 1, len(dataset) - 1)
777
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, next_idx, dataset_name)
778
 
 
779
  if video_path:
780
+ # Build metadata text
781
+ metadata_lines = []
782
+ if quality_label:
783
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
784
+ if partial_success is not None:
785
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
786
+
787
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
788
+ return (
789
+ next_idx,
790
+ video_path,
791
+ task,
792
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
793
+ gr.update(value=f"✅ Trajectory {next_idx}/{len(dataset) - 1}", visible=True)
794
+ )
795
  else:
796
+ return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
797
+
798
+ def prev_trajectory(dataset, current_idx, dataset_name):
799
+ """Go to previous trajectory."""
800
+ if dataset is None:
801
+ return 0, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
802
+ prev_idx = max(current_idx - 1, 0)
803
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, prev_idx, dataset_name)
804
+
805
+ if video_path:
806
+ # Build metadata text
807
+ metadata_lines = []
808
+ if quality_label:
809
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
810
+ if partial_success is not None:
811
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
812
+
813
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
814
+ return (
815
+ prev_idx,
816
+ video_path,
817
+ task,
818
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
819
+ gr.update(value=f"✅ Trajectory {prev_idx}/{len(dataset) - 1}", visible=True)
820
+ )
821
+ else:
822
+ return current_idx, None, "Complete the task", gr.update(visible=False), gr.update(visible=False)
823
+
824
+ def update_trajectory_on_slider_change(dataset, index, dataset_name):
825
+ """Update trajectory metadata when slider changes."""
826
+ if dataset is None:
827
+ return gr.update(visible=False), gr.update(visible=False)
828
+
829
+ video_path, task, quality_label, partial_success = get_trajectory_video_path(dataset, index, dataset_name)
830
+ if video_path:
831
+ # Build metadata text
832
+ metadata_lines = []
833
+ if quality_label:
834
+ metadata_lines.append(f"**Quality Label:** {quality_label}")
835
+ if partial_success is not None:
836
+ metadata_lines.append(f"**Partial Success:** {partial_success:.3f}")
837
+
838
+ metadata_text = "\n".join(metadata_lines) if metadata_lines else ""
839
+ return (
840
+ gr.update(value=metadata_text, visible=bool(metadata_text)),
841
+ gr.update(value=f"Trajectory {index}/{len(dataset) - 1}", visible=True)
842
+ )
843
+ else:
844
+ return gr.update(visible=False), gr.update(visible=False)
845
 
846
  # Dataset selection handlers
847
  dataset_name_single.change(
 
865
  use_dataset_video_btn.click(
866
  fn=use_dataset_video,
867
  inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
868
+ outputs=[single_video_input, task_text_input, dataset_status_single, trajectory_metadata]
869
+ )
870
+
871
+ # Navigation buttons
872
+ next_traj_btn.click(
873
+ fn=next_trajectory,
874
+ inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
875
+ outputs=[trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single]
876
+ )
877
+
878
+ prev_traj_btn.click(
879
+ fn=prev_trajectory,
880
+ inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
881
+ outputs=[trajectory_slider, single_video_input, task_text_input, trajectory_metadata, dataset_status_single]
882
+ )
883
+
884
+ # Update metadata when slider changes
885
+ trajectory_slider.change(
886
+ fn=update_trajectory_on_slider_change,
887
+ inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
888
+ outputs=[trajectory_metadata, dataset_status_single]
889
  )
890
 
891
  analyze_single_btn.click(