Anthony Liang committed on
Commit
c66a872
·
1 Parent(s): d4cfa7b

update app with model selection

Browse files
Files changed (2) hide show
  1. __pycache__/app.cpython-310.pyc +0 -0
  2. app.py +190 -49
__pycache__/app.cpython-310.pyc ADDED
Binary file (24.6 kB). View file
 
app.py CHANGED
@@ -23,7 +23,7 @@ matplotlib.use("Agg") # Use non-interactive backend
23
  import matplotlib.pyplot as plt
24
  import numpy as np
25
  import requests
26
- from typing import Any, Optional, Tuple
27
 
28
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample, SimilaritySample
29
  from rfm.evals.eval_utils import build_payload, post_batch_npy
@@ -62,14 +62,72 @@ PREDEFINED_DATASETS = [
62
  "aliangdw/usc_franka_policy_ranking",
63
  "aliangdw/utd_so101_policy_ranking",
64
  "aliangdw/utd_so101_human",
 
 
 
 
65
  ]
66
 
67
  # Global server state
68
  _server_state = {
69
  "server_url": None,
 
70
  }
71
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
74
  """Check server health and get model info."""
75
  if not server_url:
@@ -92,15 +150,7 @@ def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[
92
  pass
93
 
94
  # Try to get model info
95
- model_info_text = None
96
- try:
97
- model_info_url = server_url.rstrip("/") + "/model_info"
98
- model_info_response = requests.get(model_info_url, timeout=5.0)
99
- if model_info_response.status_code == 200:
100
- model_info_data = model_info_response.json()
101
- model_info_text = format_model_info(model_info_data)
102
- except Exception as e:
103
- logger.warning(f"Could not fetch model info: {e}")
104
 
105
  _server_state["server_url"] = server_url
106
  return (
@@ -271,11 +321,12 @@ def process_single_video(
271
  fps: float = 1.0,
272
  ) -> Tuple[Optional[str], Optional[str]]:
273
  """Process single video for progress and success predictions using eval server."""
 
274
  if not server_url:
275
- return None, "Please provide a server URL and check connection first."
276
-
277
- if not _server_state.get("server_url"):
278
- return None, "Server not connected. Please check server connection first."
279
 
280
  if video_path is None:
281
  return None, "Please provide a video."
@@ -377,11 +428,12 @@ def process_two_videos(
377
  fps: float = 1.0,
378
  ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
379
  """Process two videos for preference, similarity, or progress prediction using eval server."""
 
380
  if not server_url:
381
- return "Please provide a server URL and check connection first.", None, None
382
-
383
- if not _server_state.get("server_url"):
384
- return "Server not connected. Please check server connection first.", None, None
385
 
386
  if video_a_path is None or video_b_path is None:
387
  return "Please provide both videos.", None, None
@@ -558,41 +610,130 @@ with demo:
558
  """
559
  # RFM (Reward Foundation Model) Evaluation Server
560
 
561
- **Note:** This app connects to an eval server. Please provide the server URL and check connection before use.
562
  """
563
  )
564
 
565
- with gr.Tab("Server Setup"):
566
- gr.Markdown("### Connect to Eval Server")
567
- gr.Markdown("Enter the eval server URL and check connection.")
568
-
569
- with gr.Row():
570
- with gr.Column(scale=3):
571
- server_url_input = gr.Textbox(
572
- label="Server URL",
573
- placeholder="http://40.119.56.66:8000",
574
- value="http://40.119.56.66:8000",
575
- interactive=True,
576
- )
577
- with gr.Column(scale=1):
578
- check_connection_btn = gr.Button("Check Connection", variant="primary", size="sm")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
 
580
- server_status = gr.Markdown("Enter server URL and click 'Check Connection'")
581
- model_info_display = gr.Markdown("", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
- def on_check_connection(server_url: str):
584
- """Handle server connection check."""
585
- status, health_data, model_info_text = check_server_health(server_url)
586
- if model_info_text:
587
- return status, gr.update(value=model_info_text, visible=True)
588
- else:
589
- return status, gr.update(visible=False)
590
 
591
- check_connection_btn.click(
592
- fn=on_check_connection,
593
- inputs=[server_url_input],
594
- outputs=[server_status, model_info_display],
595
- )
596
 
597
  with gr.Tab("Progress Prediction"):
598
  gr.Markdown("### Progress & Success Prediction")
@@ -851,7 +992,7 @@ with demo:
851
 
852
  analyze_single_btn.click(
853
  fn=process_single_video,
854
- inputs=[single_video_input, task_text_input, server_url_input, fps_input_single],
855
  outputs=[progress_plot, info_output],
856
  api_name="process_single_video",
857
  )
@@ -1329,7 +1470,7 @@ with demo:
1329
 
1330
  analyze_dual_btn.click(
1331
  fn=process_two_videos,
1332
- inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_input, fps_input_dual],
1333
  outputs=[result_text, video_a_display, video_b_display],
1334
  api_name="process_two_videos",
1335
  )
 
23
  import matplotlib.pyplot as plt
24
  import numpy as np
25
  import requests
26
+ from typing import Any, List, Optional, Tuple
27
 
28
  from rfm.data.dataset_types import Trajectory, ProgressSample, PreferenceSample, SimilaritySample
29
  from rfm.evals.eval_utils import build_payload, post_batch_npy
 
62
  "aliangdw/usc_franka_policy_ranking",
63
  "aliangdw/utd_so101_policy_ranking",
64
  "aliangdw/utd_so101_human",
65
+ "jesbu1/utd_so101_clean_policy_ranking_top",
66
+ "jesbu1/utd_so101_clean_policy_ranking_wrist",
67
+ "jesbu1/mit_franka_p-rank_rfm",
68
+ "jesbu1/usc_koch_p_ranking_rfm",
69
  ]
70
 
71
  # Global server state
72
  _server_state = {
73
  "server_url": None,
74
+ "base_url": "http://40.119.56.66", # Default base URL
75
  }
76
 
77
 
78
+ def discover_available_models(base_url: str = "http://40.119.56.66", port_range: tuple = (8000, 8010)) -> List[Tuple[str, str]]:
79
+ """Discover available models by pinging ports in the specified range.
80
+
81
+ Returns:
82
+ List of tuples: [(server_url, model_name), ...]
83
+ """
84
+ available_models = []
85
+ start_port, end_port = port_range
86
+
87
+ for port in range(start_port, end_port + 1):
88
+ server_url = f"{base_url.rstrip('/')}:{port}"
89
+ try:
90
+ # Check health endpoint
91
+ health_url = f"{server_url}/health"
92
+ health_response = requests.get(health_url, timeout=2.0)
93
+ if health_response.status_code == 200:
94
+ # Try to get model info for model name
95
+ try:
96
+ model_info_url = f"{server_url}/model_info"
97
+ model_info_response = requests.get(model_info_url, timeout=2.0)
98
+ if model_info_response.status_code == 200:
99
+ model_info_data = model_info_response.json()
100
+ model_name = model_info_data.get("model_path", f"Model on port {port}")
101
+ available_models.append((server_url, model_name))
102
+ else:
103
+ # Health check passed but no model info, use port as name
104
+ available_models.append((server_url, f"Model on port {port}"))
105
+ except:
106
+ # Health check passed but couldn't get model info
107
+ available_models.append((server_url, f"Model on port {port}"))
108
+ except requests.exceptions.RequestException:
109
+ # Port not available, continue
110
+ continue
111
+
112
+ return available_models
113
+
114
+
115
+ def get_model_info_for_url(server_url: str) -> Optional[str]:
116
+ """Get formatted model info for a given server URL."""
117
+ if not server_url:
118
+ return None
119
+
120
+ try:
121
+ model_info_url = server_url.rstrip("/") + "/model_info"
122
+ model_info_response = requests.get(model_info_url, timeout=5.0)
123
+ if model_info_response.status_code == 200:
124
+ model_info_data = model_info_response.json()
125
+ return format_model_info(model_info_data)
126
+ except Exception as e:
127
+ logger.warning(f"Could not fetch model info: {e}")
128
+ return None
129
+
130
+
131
  def check_server_health(server_url: str) -> Tuple[str, Optional[dict], Optional[str]]:
132
  """Check server health and get model info."""
133
  if not server_url:
 
150
  pass
151
 
152
  # Try to get model info
153
+ model_info_text = get_model_info_for_url(server_url)
 
 
 
 
 
 
 
 
154
 
155
  _server_state["server_url"] = server_url
156
  return (
 
321
  fps: float = 1.0,
322
  ) -> Tuple[Optional[str], Optional[str]]:
323
  """Process single video for progress and success predictions using eval server."""
324
+ # Get server URL from state if not provided
325
  if not server_url:
326
+ server_url = _server_state.get("server_url")
327
+
328
+ if not server_url:
329
+ return None, "Please select a model from the dropdown above and ensure it's connected."
330
 
331
  if video_path is None:
332
  return None, "Please provide a video."
 
428
  fps: float = 1.0,
429
  ) -> Tuple[Optional[str], Optional[str], Optional[str]]:
430
  """Process two videos for preference, similarity, or progress prediction using eval server."""
431
+ # Get server URL from state if not provided
432
  if not server_url:
433
+ server_url = _server_state.get("server_url")
434
+
435
+ if not server_url:
436
+ return "Please select a model from the dropdown above and ensure it's connected.", None, None
437
 
438
  if video_a_path is None or video_b_path is None:
439
  return "Please provide both videos.", None, None
 
610
  """
611
  # RFM (Reward Foundation Model) Evaluation Server
612
 
613
+ Select a model from the dropdown below. The app will automatically discover available models.
614
  """
615
  )
616
 
617
+ # Model selector at the top
618
+ with gr.Row():
619
+ with gr.Column(scale=4):
620
+ base_url_input = gr.Textbox(
621
+ label="Base Server URL",
622
+ placeholder="http://40.119.56.66",
623
+ value="http://40.119.56.66",
624
+ interactive=True,
625
+ )
626
+ model_dropdown = gr.Dropdown(
627
+ label="Select Model",
628
+ choices=[],
629
+ value=None,
630
+ interactive=True,
631
+ info="Click 'Discover Models' to find available models on ports 8000-8010",
632
+ )
633
+ with gr.Column(scale=1):
634
+ discover_btn = gr.Button("🔍 Discover Models", variant="primary", size="lg")
635
+
636
+ with gr.Row():
637
+ server_status = gr.Markdown("Click 'Discover Models' to find available models", visible=True)
638
+
639
+ with gr.Accordion("📋 Model Information", open=False) as model_info_accordion:
640
+ model_info_display = gr.Markdown("", visible=True)
641
+
642
+ # Hidden state to store server URL and model mapping
643
+ server_url_state = gr.State(value=None)
644
+ model_url_mapping_state = gr.State(value={}) # Maps model_name -> server_url
645
+
646
+ def discover_and_select_models(base_url: str):
647
+ """Discover models and update dropdown."""
648
+ if not base_url:
649
+ return (
650
+ gr.update(choices=[], value=None),
651
+ gr.update(value="Please provide a base URL", visible=True),
652
+ gr.update(value="", visible=True),
653
+ None,
654
+ {}, # Empty mapping
655
+ )
656
+
657
+ _server_state["base_url"] = base_url
658
+ models = discover_available_models(base_url, port_range=(8000, 8010))
659
+
660
+ if not models:
661
+ return (
662
+ gr.update(choices=[], value=None),
663
+ gr.update(value="❌ No models found on ports 8000-8010. Make sure servers are running.", visible=True),
664
+ gr.update(value="", visible=True),
665
+ None,
666
+ {}, # Empty mapping
667
+ )
668
+
669
+ # Format choices: show model_name in dropdown
670
+ # Store mapping of model_name to URL in state
671
+ choices = []
672
+ url_map = {}
673
+ for url, name in models:
674
+ choices.append(name)
675
+ url_map[name] = url
676
+
677
+ # Auto-select first model
678
+ selected_choice = choices[0] if choices else None
679
+ selected_url = url_map.get(selected_choice) if selected_choice else None
680
+
681
+ # Get model info for selected model
682
+ model_info_text = get_model_info_for_url(selected_url) if selected_url else ""
683
+ status_text = f"✅ Found {len(models)} model(s). Auto-selected first model."
684
+
685
+ _server_state["server_url"] = selected_url
686
+
687
+ return (
688
+ gr.update(choices=choices, value=selected_choice),
689
+ gr.update(value=status_text, visible=True),
690
+ gr.update(value=model_info_text, visible=True),
691
+ selected_url,
692
+ url_map, # Return mapping for state
693
+ )
694
 
695
+ def on_model_selected(model_choice: str, url_mapping: dict):
696
+ """Handle model selection change."""
697
+ if not model_choice:
698
+ return (
699
+ gr.update(value="No model selected", visible=True),
700
+ gr.update(value="", visible=True),
701
+ None,
702
+ )
703
+
704
+ # Get URL from mapping
705
+ server_url = url_mapping.get(model_choice) if url_mapping else None
706
+
707
+ if not server_url:
708
+ return (
709
+ gr.update(value="Could not find server URL for selected model. Please rediscover models.", visible=True),
710
+ gr.update(value="", visible=True),
711
+ None,
712
+ )
713
+
714
+ # Get model info
715
+ model_info_text = get_model_info_for_url(server_url) or ""
716
+ status, health_data, _ = check_server_health(server_url)
717
+
718
+ _server_state["server_url"] = server_url
719
+
720
+ return (
721
+ gr.update(value=status, visible=True),
722
+ gr.update(value=model_info_text, visible=True),
723
+ server_url,
724
+ )
725
 
726
+ discover_btn.click(
727
+ fn=discover_and_select_models,
728
+ inputs=[base_url_input],
729
+ outputs=[model_dropdown, server_status, model_info_display, server_url_state, model_url_mapping_state],
730
+ )
 
 
731
 
732
+ model_dropdown.change(
733
+ fn=on_model_selected,
734
+ inputs=[model_dropdown, model_url_mapping_state],
735
+ outputs=[server_status, model_info_display, server_url_state],
736
+ )
737
 
738
  with gr.Tab("Progress Prediction"):
739
  gr.Markdown("### Progress & Success Prediction")
 
992
 
993
  analyze_single_btn.click(
994
  fn=process_single_video,
995
+ inputs=[single_video_input, task_text_input, server_url_state, fps_input_single],
996
  outputs=[progress_plot, info_output],
997
  api_name="process_single_video",
998
  )
 
1470
 
1471
  analyze_dual_btn.click(
1472
  fn=process_two_videos,
1473
+ inputs=[video_a_input, video_b_input, task_text_dual, prediction_type, server_url_state, fps_input_dual],
1474
  outputs=[result_text, video_a_display, video_b_display],
1475
  api_name="process_two_videos",
1476
  )