Anthony Liang committed on
Commit
6cf09b8
·
1 Parent(s): 86399fe

select videos

Browse files
Files changed (1) hide show
  1. app.py +314 -11
app.py CHANGED
@@ -9,6 +9,7 @@ import os
9
  import tempfile
10
  from pathlib import Path
11
  from typing import Optional, Tuple
 
12
 
13
  import gradio as gr
14
  import spaces # Required for ZeroGPU
@@ -23,16 +24,51 @@ from typing import Any, Optional, Tuple
23
 
24
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
25
  from rfm.evals.eval_utils import build_payload, post_batch_npy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Global server state
28
  _server_state = {
29
  "server_url": None,
30
  }
31
 
32
- def check_server_health(server_url: str) -> Tuple[str, Optional[dict]]:
33
  """Check server health and get model info."""
34
  if not server_url:
35
- return "Please provide a server URL.", None
36
 
37
  try:
38
  url = server_url.rstrip("/") + "/health"
@@ -50,24 +86,187 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict]]:
50
  except:
51
  pass
52
 
 
 
 
 
 
 
 
 
 
 
 
53
  _server_state["server_url"] = server_url
54
- return f"Server connected: {health_data.get('available_gpus', 0)}/{health_data.get('total_gpus', 0)} GPUs available", health_data
55
  except requests.exceptions.RequestException as e:
56
- return f"Error connecting to server: {str(e)}", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def extract_frames(video_path: str, max_frames: int = 16, fps: float = 1.0) -> np.ndarray:
60
- """Extract frames from video file as numpy array (T, H, W, C)."""
 
 
 
61
  if video_path is None:
62
  return None
63
 
64
  if isinstance(video_path, tuple):
65
  video_path = video_path[0]
66
 
67
- if not os.path.exists(video_path):
 
 
 
 
 
68
  return None
69
 
70
  try:
 
71
  vr = decord.VideoReader(video_path, num_threads=1)
72
  total_frames = len(vr)
73
 
@@ -83,7 +282,7 @@ def extract_frames(video_path: str, max_frames: int = 16, fps: float = 1.0) -> n
83
  del vr
84
  return frames_array
85
  except Exception as e:
86
- print(f"Error extracting frames: {e}")
87
  return None
88
 
89
 
@@ -410,22 +609,60 @@ with demo:
410
  check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")
411
 
412
  server_status = gr.Markdown("Enter server URL and click 'Check Connection'")
 
413
 
414
  def on_check_connection(server_url: str):
415
  """Handle server connection check."""
416
- status, health_data = check_server_health(server_url)
417
- return status
 
 
 
418
 
419
  check_connection_btn.click(
420
  fn=on_check_connection,
421
  inputs=[server_url_input],
422
- outputs=[server_status],
423
  )
424
 
425
  with gr.Tab("Progress Prediction"):
426
  gr.Markdown("### Progress & Success Prediction")
 
 
427
  with gr.Row():
428
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  single_video_input = gr.Video(label="Upload Video", height=300)
430
  task_text_input = gr.Textbox(
431
  label="Task Description",
@@ -446,7 +683,73 @@ with demo:
446
  progress_plot = gr.Image(label="Progress Prediction", height=400)
447
  success_plot = gr.Image(label="Success Prediction", height=400)
448
  info_output = gr.Markdown("")
449
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  analyze_single_btn.click(
451
  fn=process_single_video,
452
  inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],
 
9
  import tempfile
10
  from pathlib import Path
11
  from typing import Optional, Tuple
12
+ import logging
13
 
14
  import gradio as gr
15
  import spaces # Required for ZeroGPU
 
24
 
25
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample
26
  from rfm.evals.eval_utils import build_payload, post_batch_npy
27
+ from datasets import load_dataset as load_dataset_hf, get_dataset_config_names
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # Predefined dataset names (same as visualizer)
32
+ PREDEFINED_DATASETS = [
33
+ "abraranwar/agibotworld_alpha_rfm",
34
+ "abraranwar/libero_rfm",
35
+ "abraranwar/usc_koch_rewind_rfm",
36
+ "aliangdw/metaworld",
37
+ "anqil/rh20t_rfm",
38
+ "anqil/rh20t_subset_rfm",
39
+ "jesbu1/auto_eval_rfm",
40
+ "jesbu1/egodex_rfm",
41
+ "jesbu1/epic_rfm",
42
+ "jesbu1/fino_net_rfm",
43
+ "jesbu1/failsafe_rfm",
44
+ "jesbu1/hand_paired_rfm",
45
+ "jesbu1/galaxea_rfm",
46
+ "jesbu1/h2r_rfm",
47
+ "jesbu1/humanoid_everyday_rfm",
48
+ "jesbu1/molmoact_rfm",
49
+ "jesbu1/motif_rfm",
50
+ "jesbu1/oxe_rfm",
51
+ "jesbu1/oxe_rfm_eval",
52
+ "jesbu1/ph2d_rfm",
53
+ "jesbu1/racer_rfm",
54
+ "jesbu1/roboarena_0825_rfm",
55
+ "jesbu1/soar_rfm",
56
+ "ykorkmaz/libero_failure_rfm",
57
+ "aliangdw/usc_xarm_policy_ranking",
58
+ "aliangdw/usc_franka_policy_ranking",
59
+ "aliangdw/utd_so101_policy_ranking",
60
+ "aliangdw/utd_so101_human"
61
+ ]
62
 
63
  # Global server state
64
  _server_state = {
65
  "server_url": None,
66
  }
67
 
68
+ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
69
  """Check server health and get model info."""
70
  if not server_url:
71
+ return "Please provide a server URL.", None, None
72
 
73
  try:
74
  url = server_url.rstrip("/") + "/health"
 
86
  except:
87
  pass
88
 
89
+ # Try to get model info
90
+ model_info_text = None
91
+ try:
92
+ model_info_url = server_url.rstrip("/") + "/model_info"
93
+ model_info_response = requests.get(model_info_url, timeout=5.0)
94
+ if model_info_response.status_code == 200:
95
+ model_info_data = model_info_response.json()
96
+ model_info_text = format_model_info(model_info_data)
97
+ except Exception as e:
98
+ logger.warning(f"Could not fetch model info: {e}")
99
+
100
  _server_state["server_url"] = server_url
101
+ return f"Server connected: {health_data.get('available_gpus', 0)}/{health_data.get('total_gpus', 0)} GPUs available", health_data, model_info_text
102
  except requests.exceptions.RequestException as e:
103
+ return f"Error connecting to server: {str(e)}", None, None
104
+
105
+
106
def format_model_info(model_info: dict) -> str:
    """Render server model info and experiment config as a markdown report.

    Args:
        model_info: Payload from the server's ``/model_info`` endpoint.
            Recognized keys (all optional): ``model_path``, ``num_gpus``,
            ``model_architecture``, ``experiment_config``.

    Returns:
        Markdown text summarizing the model, its architecture, and the
        training/experiment configuration.
    """
    lines = ["## Model Information\n"]

    model_path = model_info.get("model_path", "Unknown")
    lines.append(f"**Model Path:** `{model_path}`\n")

    num_gpus = model_info.get("num_gpus", "Unknown")
    lines.append(f"**Number of GPUs:** {num_gpus}\n")

    # Architecture section is skipped when absent or when the server
    # reported an error while introspecting the model.
    model_arch = model_info.get("model_architecture", {})
    if model_arch and "error" not in model_arch:
        lines.append("\n## Model Architecture\n")

        model_class = model_arch.get("model_class", "Unknown")
        model_module = model_arch.get("model_module", "Unknown")
        lines.append(f"- **Model Class:** `{model_class}`\n")
        lines.append(f"- **Module:** `{model_module}`\n")

        # Parameter counts: each line is emitted only when the server
        # provided the corresponding value.
        total_params = model_arch.get("total_parameters")
        trainable_params = model_arch.get("trainable_parameters")
        frozen_params = model_arch.get("frozen_parameters")
        trainable_pct = model_arch.get("trainable_percentage")

        if total_params is not None:
            lines.append("\n### Parameter Statistics\n")
            lines.append(f"- **Total Parameters:** {total_params:,}\n")
        if trainable_params is not None:
            lines.append(f"- **Trainable Parameters:** {trainable_params:,}\n")
        if frozen_params is not None:
            lines.append(f"- **Frozen Parameters:** {frozen_params:,}\n")
        if trainable_pct is not None:
            lines.append(f"- **Trainable Percentage:** {trainable_pct:.2f}%\n")

        arch_summary = model_arch.get("architecture_summary", [])
        if arch_summary:
            lines.append("\n### Architecture Summary (Top-Level Modules)\n")
            for module_info in arch_summary[:10]:  # show at most 10 modules
                name = module_info.get("name", "Unknown")
                module_type = module_info.get("type", "Unknown")
                params = module_info.get("parameters", 0)
                lines.append(f"- **{name}** (`{module_type}`): {params:,} parameters\n")

    # Experiment configuration, broken into model / data / training subsections.
    exp_config = model_info.get("experiment_config", {})
    if exp_config:
        lines.append("\n## Experiment Configuration\n")

        model_cfg = exp_config.get("model", {})
        if model_cfg:
            lines.append("### Model Configuration\n")
            lines.append(f"- **Base Model:** `{model_cfg.get('base_model_id', 'N/A')}`\n")
            lines.append(f"- **Model Type:** `{model_cfg.get('model_type', 'N/A')}`\n")
            lines.append(f"- **Train Progress Head:** {model_cfg.get('train_progress_head', False)}\n")
            lines.append(f"- **Train Preference Head:** {model_cfg.get('train_preference_head', False)}\n")
            lines.append(f"- **Train Similarity Head:** {model_cfg.get('train_similarity_head', False)}\n")
            lines.append(f"- **Train Success Head:** {model_cfg.get('train_success_head', False)}\n")
            lines.append(f"- **Use PEFT:** {model_cfg.get('use_peft', False)}\n")
            lines.append(f"- **Use Unsloth:** {model_cfg.get('use_unsloth', False)}\n")

        data_cfg = exp_config.get("data", {})
        if data_cfg:
            lines.append("\n### Data Configuration\n")
            lines.append(f"- **Max Frames:** {data_cfg.get('max_frames', 'N/A')}\n")
            lines.append(f"- **Resized Dimensions:** {data_cfg.get('resized_height', 'N/A')}x{data_cfg.get('resized_width', 'N/A')}\n")
            train_datasets = data_cfg.get('train_datasets', [])
            if train_datasets:
                lines.append(f"- **Train Datasets:** {', '.join(train_datasets)}\n")
            eval_datasets = data_cfg.get('eval_datasets', [])
            if eval_datasets:
                lines.append(f"- **Eval Datasets:** {', '.join(eval_datasets)}\n")

        training_cfg = exp_config.get("training", {})
        if training_cfg:
            lines.append("\n### Training Configuration\n")
            lines.append(f"- **Learning Rate:** {training_cfg.get('learning_rate', 'N/A')}\n")
            lines.append(f"- **Batch Size:** {training_cfg.get('per_device_train_batch_size', 'N/A')}\n")
            lines.append(f"- **Gradient Accumulation Steps:** {training_cfg.get('gradient_accumulation_steps', 'N/A')}\n")
            lines.append(f"- **Max Steps:** {training_cfg.get('max_steps', 'N/A')}\n")

    return "".join(lines)
195
+
196
+
197
def load_rfm_dataset(dataset_name, config_name):
    """Load an RFM dataset split from the HuggingFace Hub.

    Returns:
        A ``(dataset, message)`` tuple: the loaded train split plus a success
        message, or ``(None, message)`` describing why loading failed.
    """
    # Both pieces are required to locate the split on the Hub.
    if not dataset_name or not config_name:
        return None, "Please provide both dataset name and configuration"

    try:
        dataset = load_dataset_hf(dataset_name, name=config_name, split="train")
        if len(dataset) == 0:
            return None, f"Dataset {dataset_name}/{config_name} is empty"
        return dataset, f"Loaded {len(dataset)} trajectories from {dataset_name}/{config_name}"
    except Exception as e:
        # Map the most common failure modes to friendlier messages.
        error_msg = str(e)
        lowered = error_msg.lower()
        if "not found" in lowered:
            return None, f"Dataset or configuration not found: {dataset_name}/{config_name}"
        if "authentication" in lowered:
            return None, f"Authentication required for {dataset_name}"
        return None, f"Error loading dataset: {error_msg}"
217
+
218
+
219
def get_available_configs(dataset_name):
    """Return the list of Hub config names for *dataset_name*.

    Any lookup failure is logged and reported as an empty list so the UI
    dropdown can simply be cleared.
    """
    try:
        return get_dataset_config_names(dataset_name)
    except Exception as e:
        logger.warning(f"Error getting configs for {dataset_name}: {e}")
        return []
227
+
228
+
229
def get_trajectory_video_path(dataset, index, dataset_name):
    """Resolve a trajectory's video URL and task description.

    Looks up ``dataset[index]`` and, when its ``frames`` field is a relative
    path string, builds the matching HuggingFace Hub ``resolve/main`` URL
    (falling back to the ``aliangdw/rfm`` repo when no dataset name is given).

    Returns:
        ``(video_url, task)`` on success, ``(None, None)`` otherwise.
    """
    try:
        entry = dataset[int(index)]
        frames = entry["frames"]
        # Only string entries are relative video paths we can resolve.
        if not isinstance(frames, str):
            return None, None
        repo = dataset_name if dataset_name else "aliangdw/rfm"
        video_url = f"https://huggingface.co/datasets/{repo}/resolve/main/{frames}"
        return video_url, entry.get("task", "Complete the task")
    except Exception as e:
        logger.error(f"Error getting trajectory video path: {e}")
        return None, None
247
 
248
 
249
  def extract_frames(video_path: str, max_frames: int = 16, fps: float = 1.0) -> np.ndarray:
250
+ """Extract frames from video file as numpy array (T, H, W, C).
251
+
252
+ Supports both local file paths and URLs (e.g., HuggingFace Hub URLs).
253
+ """
254
  if video_path is None:
255
  return None
256
 
257
  if isinstance(video_path, tuple):
258
  video_path = video_path[0]
259
 
260
+ # Check if it's a URL or local file
261
+ is_url = video_path.startswith(("http://", "https://"))
262
+ is_local_file = os.path.exists(video_path) if not is_url else False
263
+
264
+ if not is_url and not is_local_file:
265
+ logger.warning(f"Video path does not exist: {video_path}")
266
  return None
267
 
268
  try:
269
+ # decord.VideoReader can handle both local files and URLs
270
  vr = decord.VideoReader(video_path, num_threads=1)
271
  total_frames = len(vr)
272
 
 
282
  del vr
283
  return frames_array
284
  except Exception as e:
285
+ logger.error(f"Error extracting frames from {video_path}: {e}")
286
  return None
287
 
288
 
 
609
  check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")
610
 
611
  server_status = gr.Markdown("Enter server URL and click 'Check Connection'")
612
+ model_info_display = gr.Markdown("", visible=False)
613
 
614
def on_check_connection(server_url: str):
    """Handle the connection check: status text plus model-info panel update."""
    status, health_data, model_info_text = check_server_health(server_url)
    # Hide the model-info markdown unless the server returned something.
    if not model_info_text:
        return status, gr.update(visible=False)
    return status, gr.update(value=model_info_text, visible=True)
621
 
622
  check_connection_btn.click(
623
  fn=on_check_connection,
624
  inputs=[server_url_input],
625
+ outputs=[server_status, model_info_display],
626
  )
627
 
628
  with gr.Tab("Progress Prediction"):
629
  gr.Markdown("### Progress & Success Prediction")
630
+ gr.Markdown("Upload a video or select one from a dataset to get progress predictions.")
631
+
632
  with gr.Row():
633
  with gr.Column():
634
+ with gr.Accordion("📁 Select from Dataset", open=False):
635
+ dataset_name_single = gr.Dropdown(
636
+ choices=PREDEFINED_DATASETS,
637
+ value="jesbu1/oxe_rfm",
638
+ label="Dataset Name",
639
+ allow_custom_value=True
640
+ )
641
+ config_name_single = gr.Dropdown(
642
+ choices=[],
643
+ value="",
644
+ label="Configuration Name",
645
+ allow_custom_value=True
646
+ )
647
+ with gr.Row():
648
+ refresh_configs_btn = gr.Button("🔄 Refresh Configs", variant="secondary", size="sm")
649
+ load_dataset_btn = gr.Button("Load Dataset", variant="secondary", size="sm")
650
+
651
+ dataset_status_single = gr.Markdown("", visible=False)
652
+ trajectory_slider = gr.Slider(
653
+ minimum=0,
654
+ maximum=0,
655
+ step=1,
656
+ value=0,
657
+ label="Trajectory Index",
658
+ interactive=False
659
+ )
660
+ use_dataset_video_btn = gr.Button("Use Selected Video", variant="secondary")
661
+
662
+ gr.Markdown("---")
663
+ gr.Markdown("**OR**")
664
+ gr.Markdown("---")
665
+
666
  single_video_input = gr.Video(label="Upload Video", height=300)
667
  task_text_input = gr.Textbox(
668
  label="Task Description",
 
683
  progress_plot = gr.Image(label="Progress Prediction", height=400)
684
  success_plot = gr.Image(label="Success Prediction", height=400)
685
  info_output = gr.Markdown("")
686
+
687
+ # State variables for dataset
688
+ current_dataset_single = gr.State(None)
689
+
690
def update_config_choices_single(dataset_name):
    """Refresh the configuration dropdown for the selected dataset.

    Selects the first available config, or clears the dropdown when the
    dataset is unset or its configs cannot be fetched.
    """
    if not dataset_name:
        return gr.update(choices=[], value="")
    try:
        configs = get_available_configs(dataset_name)
        if configs:
            return gr.update(choices=configs, value=configs[0])
    except Exception as e:
        logger.warning(f"Could not fetch configs: {e}")
    # No configs found (or an error occurred): clear the dropdown.
    return gr.update(choices=[], value="")
703
+
704
def load_dataset_single(dataset_name, config_name):
    """Load the chosen dataset and reconfigure the trajectory slider."""
    dataset, status = load_rfm_dataset(dataset_name, config_name)
    status_update = gr.update(value=status, visible=True)
    if dataset is None:
        # Loading failed: show the message and disable the slider.
        return None, status_update, gr.update(maximum=0, value=0, interactive=False)
    last = len(dataset) - 1
    slider_update = gr.update(
        maximum=last,
        value=0,
        interactive=True,
        label=f"Trajectory Index (0 to {last})",
    )
    return dataset, status_update, slider_update
716
+
717
def use_dataset_video(dataset, index, dataset_name):
    """Copy the selected trajectory's video URL and task into the input widgets."""
    if dataset is None:
        return None, "Complete the task", gr.update(value="No dataset loaded", visible=True)
    video_path, task = get_trajectory_video_path(dataset, index, dataset_name)
    if not video_path:
        return None, "Complete the task", gr.update(value="❌ Error loading trajectory", visible=True)
    return video_path, task, gr.update(value=f"✅ Loaded trajectory {index} from dataset", visible=True)
727
+
728
+ # Dataset selection handlers
729
+ dataset_name_single.change(
730
+ fn=update_config_choices_single,
731
+ inputs=[dataset_name_single],
732
+ outputs=[config_name_single]
733
+ )
734
+
735
+ refresh_configs_btn.click(
736
+ fn=update_config_choices_single,
737
+ inputs=[dataset_name_single],
738
+ outputs=[config_name_single]
739
+ )
740
+
741
+ load_dataset_btn.click(
742
+ fn=load_dataset_single,
743
+ inputs=[dataset_name_single, config_name_single],
744
+ outputs=[current_dataset_single, dataset_status_single, trajectory_slider]
745
+ )
746
+
747
+ use_dataset_video_btn.click(
748
+ fn=use_dataset_video,
749
+ inputs=[current_dataset_single, trajectory_slider, dataset_name_single],
750
+ outputs=[single_video_input, task_text_input, dataset_status_single]
751
+ )
752
+
753
  analyze_single_btn.click(
754
  fn=process_single_video,
755
  inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],