TimWindecker commited on
Commit
641159b
·
verified ·
1 Parent(s): 4a5921f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +38 -64
src/streamlit_app.py CHANGED
@@ -1,10 +1,16 @@
 
 
 
1
  import streamlit as st
2
  import pandas as pd
 
3
  import plotly.graph_objects as go
4
  import plotly.express as px
5
  from io import StringIO
6
  import json
7
 
 
 
8
  # Page config
9
  st.set_page_config(
10
  page_title="NaviTrace Leaderboard",
@@ -93,62 +99,48 @@ st.markdown("""
93
  </style>
94
  """, unsafe_allow_html=True)
95
 
96
- # Sample data - Replace with your actual data
97
  def load_data():
98
- return pd.DataFrame({
99
- 'Model': ['GPT-4', 'Claude-3.5-Sonnet', 'Gemini-Pro', 'Llama-3-70B', 'Mistral-Large'],
100
- 'Total Score': [87.5, 85.2, 82.1, 78.3, 75.6],
101
- 'Embodiment-A': [90.2, 87.5, 84.3, 80.1, 77.8],
102
- 'Embodiment-B': [85.8, 84.1, 81.2, 77.9, 74.5],
103
- 'Embodiment-C': [86.5, 84.0, 80.8, 76.9, 74.5],
104
- 'Category-Spatial': [88.9, 86.7, 83.5, 79.8, 76.9],
105
- 'Category-Temporal': [86.3, 84.2, 81.0, 77.5, 75.1],
106
- 'Category-Object': [87.3, 84.7, 81.8, 77.6, 74.8],
107
- })
108
 
109
- def calculate_score(results_df):
110
- """
111
- Calculate score using private test split ground truth.
112
- This function should:
113
- 1. Load the private test split ground truth (not exposed to users)
114
- 2. Compare uploaded predictions with ground truth
115
- 3. Calculate metrics per embodiment and category
116
- 4. Return detailed scores
117
-
118
- Args:
119
- results_df: DataFrame with columns ['sample_id', 'prediction', ...]
120
-
121
- Returns:
122
- dict: Scores breakdown or None if error
123
- """
124
  try:
125
- # TODO: Implement your scoring logic here
126
- # Example structure:
127
- # ground_truth = load_private_test_split() # From secure location
128
- # scores = evaluate_predictions(results_df, ground_truth)
 
 
 
129
 
130
- # Placeholder - replace with actual calculation
131
- scores = {
132
- 'Total Score': 85.0,
133
- 'Embodiment-A': 87.0,
134
- 'Embodiment-B': 84.0,
135
- 'Embodiment-C': 84.0,
136
- 'Category-Spatial': 86.0,
137
- 'Category-Temporal': 85.0,
138
- 'Category-Object': 84.0,
139
- }
140
- return scores
 
 
 
 
 
 
 
 
141
  except Exception as e:
142
  st.error(f"Error calculating score: {str(e)}")
143
  return None
144
 
145
  def validate_tsv_format(uploaded_file):
146
  """Validate that the uploaded TSV has the correct format"""
 
147
  try:
148
  df = pd.read_csv(uploaded_file, sep='\t')
149
- # TODO: Add your specific validation logic
150
  # Check for required columns, data types, etc.
151
- required_cols = ['sample_id', 'prediction'] # Adjust as needed
152
  if not all(col in df.columns for col in required_cols):
153
  return False, f"Missing required columns. Expected: {required_cols}"
154
  return True, df
@@ -157,6 +149,7 @@ def validate_tsv_format(uploaded_file):
157
 
158
  def create_bar_chart(df, view_type):
159
  """Create interactive bar chart based on view type"""
 
160
  if view_type == "Total Score":
161
  fig = go.Figure(data=[
162
  go.Bar(
@@ -233,25 +226,6 @@ def create_bar_chart(df, view_type):
233
 
234
  return fig
235
 
236
- # TODO remove # Serve only the chart as JSON if parameter "only_chart" is set
237
- # # E.g. https://huggingface.co/spaces/leggedrobotics/navitrace_leaderboard/?only_chart=total_score
238
- # params = st.query_params
239
- # if "only_chart" in params and params["only_chart"] in ["total_score", "per_embodiment", "per_category"]:
240
- # if params["only_chart"] == "total_score":
241
- # view_type = "Total Score"
242
- # elif params["only_chart"] == "per_embodiment":
243
- # view_type = "Per Embodiment"
244
- # elif params["only_chart"] == "per_category":
245
- # view_type = "Per Category"
246
-
247
- # # Create chart
248
- # df = load_data()
249
- # fig = create_bar_chart(df, view_type)
250
-
251
- # # Only output JSON
252
- # st.write(fig.to_json())
253
- # st.stop()
254
-
255
  # Header
256
  st.markdown("""
257
  <div class="header-container">
@@ -278,8 +252,8 @@ df = load_data()
278
 
279
  # Add user's model if it exists in session state
280
  if 'user_results' in st.session_state:
281
- user_row = pd.DataFrame([st.session_state.user_results])
282
- df = pd.concat([user_row, df], ignore_index=True)
283
 
284
  # View selector
285
  view_type = st.selectbox(
 
1
+ from src.score_calculation.score import score_predictions
2
+ from datasets import load_dataset
3
+ import multiprocessing
4
  import streamlit as st
5
  import pandas as pd
6
+ from pathlib import Path
7
  import plotly.graph_objects as go
8
  import plotly.express as px
9
  from io import StringIO
10
  import json
11
 
12
+ RESULTS_DIR = "results/"
13
+
14
  # Page config
15
  st.set_page_config(
16
  page_title="NaviTrace Leaderboard",
 
99
  </style>
100
  """, unsafe_allow_html=True)
101
 
 
102
  def load_data():
103
+ """Load all result files as one data frame"""
 
 
 
 
 
 
 
 
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  try:
106
+ # Load all results files
107
+ all_dfs = []
108
+ for file_path in Path(RESULTS_DIR).glob('*.tsv'):
109
+ df = pd.read_csv(file_path, sep='\t')
110
+ model_name = file_path.stem
111
+ df["model"] = model_name
112
+ all_dfs.append(df)
113
 
114
+ # Concatenate all DataFrames into one
115
+ if all_dfs:
116
+ final_df = pd.concat(all_dfs, ignore_index=True)
117
+
118
+ return final_df
119
+ except Exception as e:
120
+ st.error(f"Error loading data: {str(e)}")
121
+ return None
122
+
123
+ def calculate_score(results_df):
124
+ """Calculate score using private test split ground truth."""
125
+
126
+ try:
127
+ # Access to private dataset with test labels
128
+ login(token=os.environ.get("HF_TOKEN"))
129
+ dataset = load_dataset(os.environ.get("HF_DATASET_ID"), split="test")
130
+
131
+ # Calculate score
132
+ return score_predictions(results_df, dataset)
133
  except Exception as e:
134
  st.error(f"Error calculating score: {str(e)}")
135
  return None
136
 
137
  def validate_tsv_format(uploaded_file):
138
  """Validate that the uploaded TSV has the correct format"""
139
+
140
  try:
141
  df = pd.read_csv(uploaded_file, sep='\t')
 
142
  # Check for required columns, data types, etc.
143
+ required_cols = ["sample_id", "embodiment", "category", "prediction"]
144
  if not all(col in df.columns for col in required_cols):
145
  return False, f"Missing required columns. Expected: {required_cols}"
146
  return True, df
 
149
 
150
  def create_bar_chart(df, view_type):
151
  """Create interactive bar chart based on view type"""
152
+
153
  if view_type == "Total Score":
154
  fig = go.Figure(data=[
155
  go.Bar(
 
226
 
227
  return fig
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  # Header
230
  st.markdown("""
231
  <div class="header-container">
 
252
 
253
  # Add user's model if it exists in session state
254
  if 'user_results' in st.session_state:
255
+ user_results = pd.DataFrame([st.session_state.user_results])
256
+ df = pd.concat([user_results, df], ignore_index=True)
257
 
258
  # View selector
259
  view_type = st.selectbox(