Update src/streamlit_app.py
Browse files- src/streamlit_app.py +38 -64
src/streamlit_app.py
CHANGED
|
@@ -1,10 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
|
|
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import plotly.express as px
|
| 5 |
from io import StringIO
|
| 6 |
import json
|
| 7 |
|
|
|
|
|
|
|
| 8 |
# Page config
|
| 9 |
st.set_page_config(
|
| 10 |
page_title="NaviTrace Leaderboard",
|
|
@@ -93,62 +99,48 @@ st.markdown("""
|
|
| 93 |
</style>
|
| 94 |
""", unsafe_allow_html=True)
|
| 95 |
|
| 96 |
-
# Sample data - Replace with your actual data
|
| 97 |
def load_data():
|
| 98 |
-
|
| 99 |
-
'Model': ['GPT-4', 'Claude-3.5-Sonnet', 'Gemini-Pro', 'Llama-3-70B', 'Mistral-Large'],
|
| 100 |
-
'Total Score': [87.5, 85.2, 82.1, 78.3, 75.6],
|
| 101 |
-
'Embodiment-A': [90.2, 87.5, 84.3, 80.1, 77.8],
|
| 102 |
-
'Embodiment-B': [85.8, 84.1, 81.2, 77.9, 74.5],
|
| 103 |
-
'Embodiment-C': [86.5, 84.0, 80.8, 76.9, 74.5],
|
| 104 |
-
'Category-Spatial': [88.9, 86.7, 83.5, 79.8, 76.9],
|
| 105 |
-
'Category-Temporal': [86.3, 84.2, 81.0, 77.5, 75.1],
|
| 106 |
-
'Category-Object': [87.3, 84.7, 81.8, 77.6, 74.8],
|
| 107 |
-
})
|
| 108 |
|
| 109 |
-
def calculate_score(results_df):
|
| 110 |
-
"""
|
| 111 |
-
Calculate score using private test split ground truth.
|
| 112 |
-
This function should:
|
| 113 |
-
1. Load the private test split ground truth (not exposed to users)
|
| 114 |
-
2. Compare uploaded predictions with ground truth
|
| 115 |
-
3. Calculate metrics per embodiment and category
|
| 116 |
-
4. Return detailed scores
|
| 117 |
-
|
| 118 |
-
Args:
|
| 119 |
-
results_df: DataFrame with columns ['sample_id', 'prediction', ...]
|
| 120 |
-
|
| 121 |
-
Returns:
|
| 122 |
-
dict: Scores breakdown or None if error
|
| 123 |
-
"""
|
| 124 |
try:
|
| 125 |
-
#
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
#
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
except Exception as e:
|
| 142 |
st.error(f"Error calculating score: {str(e)}")
|
| 143 |
return None
|
| 144 |
|
| 145 |
def validate_tsv_format(uploaded_file):
|
| 146 |
"""Validate that the uploaded TSV has the correct format"""
|
|
|
|
| 147 |
try:
|
| 148 |
df = pd.read_csv(uploaded_file, sep='\t')
|
| 149 |
-
# TODO: Add your specific validation logic
|
| 150 |
# Check for required columns, data types, etc.
|
| 151 |
-
required_cols = [
|
| 152 |
if not all(col in df.columns for col in required_cols):
|
| 153 |
return False, f"Missing required columns. Expected: {required_cols}"
|
| 154 |
return True, df
|
|
@@ -157,6 +149,7 @@ def validate_tsv_format(uploaded_file):
|
|
| 157 |
|
| 158 |
def create_bar_chart(df, view_type):
|
| 159 |
"""Create interactive bar chart based on view type"""
|
|
|
|
| 160 |
if view_type == "Total Score":
|
| 161 |
fig = go.Figure(data=[
|
| 162 |
go.Bar(
|
|
@@ -233,25 +226,6 @@ def create_bar_chart(df, view_type):
|
|
| 233 |
|
| 234 |
return fig
|
| 235 |
|
| 236 |
-
# TODO remove # Serve only the chart as JSON if parameter "only_chart" is set
|
| 237 |
-
# # E.g. https://huggingface.co/spaces/leggedrobotics/navitrace_leaderboard/?only_chart=total_score
|
| 238 |
-
# params = st.query_params
|
| 239 |
-
# if "only_chart" in params and params["only_chart"] in ["total_score", "per_embodiment", "per_category"]:
|
| 240 |
-
# if params["only_chart"] == "total_score":
|
| 241 |
-
# view_type = "Total Score"
|
| 242 |
-
# elif params["only_chart"] == "per_embodiment":
|
| 243 |
-
# view_type = "Per Embodiment"
|
| 244 |
-
# elif params["only_chart"] == "per_category":
|
| 245 |
-
# view_type = "Per Category"
|
| 246 |
-
|
| 247 |
-
# # Create chart
|
| 248 |
-
# df = load_data()
|
| 249 |
-
# fig = create_bar_chart(df, view_type)
|
| 250 |
-
|
| 251 |
-
# # Only output JSON
|
| 252 |
-
# st.write(fig.to_json())
|
| 253 |
-
# st.stop()
|
| 254 |
-
|
| 255 |
# Header
|
| 256 |
st.markdown("""
|
| 257 |
<div class="header-container">
|
|
@@ -278,8 +252,8 @@ df = load_data()
|
|
| 278 |
|
| 279 |
# Add user's model if it exists in session state
|
| 280 |
if 'user_results' in st.session_state:
|
| 281 |
-
|
| 282 |
-
df = pd.concat([
|
| 283 |
|
| 284 |
# View selector
|
| 285 |
view_type = st.selectbox(
|
|
|
|
| 1 |
+
from src.score_calculation.score import score_predictions
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
import multiprocessing
|
| 4 |
import streamlit as st
|
| 5 |
import pandas as pd
|
| 6 |
+
from pathlib import Path
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
import plotly.express as px
|
| 9 |
from io import StringIO
|
| 10 |
import json
|
| 11 |
|
| 12 |
+
RESULTS_DIR = "results/"
|
| 13 |
+
|
| 14 |
# Page config
|
| 15 |
st.set_page_config(
|
| 16 |
page_title="NaviTrace Leaderboard",
|
|
|
|
| 99 |
</style>
|
| 100 |
""", unsafe_allow_html=True)
|
| 101 |
|
|
|
|
| 102 |
def load_data():
    """Load every result file in RESULTS_DIR into a single DataFrame.

    Each ``*.tsv`` file located directly inside ``RESULTS_DIR`` is read as a
    tab-separated table and tagged with a ``model`` column holding the file
    name without its extension.

    Returns:
        The row-wise concatenation of all result tables (with a fresh
        index), or ``None`` if no result file exists or loading fails. In
        both of those cases an error is shown via ``st.error``.
    """
    try:
        all_dfs = []
        # Path.glob yields entries in arbitrary, platform-dependent order;
        # sorting keeps the combined frame stable across runs.
        for file_path in sorted(Path(RESULTS_DIR).glob('*.tsv')):
            df = pd.read_csv(file_path, sep='\t')
            df["model"] = file_path.stem
            all_dfs.append(df)

        if not all_dfs:
            # Report the empty case explicitly instead of failing on an
            # unassigned result variable and surfacing a confusing message.
            st.error(
                f"Error loading data: no '*.tsv' result files found in {RESULTS_DIR!r}"
            )
            return None

        # Concatenate all DataFrames into one
        return pd.concat(all_dfs, ignore_index=True)
    except Exception as e:
        st.error(f"Error loading data: {str(e)}")
        return None
|
| 122 |
+
|
| 123 |
+
def calculate_score(results_df):
    """Calculate score using private test split ground truth.

    The dataset identifier is read from the ``HF_DATASET_ID`` environment
    variable and the access token from ``HF_TOKEN``.

    Args:
        results_df: DataFrame of uploaded predictions; passed unchanged to
            ``score_predictions``.

    Returns:
        Whatever ``score_predictions`` returns for the predictions and the
        ``test`` split, or ``None`` (after showing ``st.error``) if anything
        fails.
    """
    # Imported here so the function does not rely on a module-level
    # ``import os`` that the file's import block does not provide.
    import os

    try:
        dataset_id = os.environ.get("HF_DATASET_ID")
        if not dataset_id:
            # Fail with a clear message rather than passing None to
            # load_dataset.
            raise ValueError("environment variable HF_DATASET_ID is not set")

        # Access to private dataset with test labels. The token is handed
        # directly to load_dataset, so no separate login helper is required.
        dataset = load_dataset(
            dataset_id,
            split="test",
            token=os.environ.get("HF_TOKEN"),
        )

        # Calculate score
        return score_predictions(results_df, dataset)
    except Exception as e:
        st.error(f"Error calculating score: {str(e)}")
        return None
|
| 136 |
|
| 137 |
def validate_tsv_format(uploaded_file):
|
| 138 |
"""Validate that the uploaded TSV has the correct format"""
|
| 139 |
+
|
| 140 |
try:
|
| 141 |
df = pd.read_csv(uploaded_file, sep='\t')
|
|
|
|
| 142 |
# Check for required columns, data types, etc.
|
| 143 |
+
required_cols = ["sample_id", "embodiment", "category", "prediction"]
|
| 144 |
if not all(col in df.columns for col in required_cols):
|
| 145 |
return False, f"Missing required columns. Expected: {required_cols}"
|
| 146 |
return True, df
|
|
|
|
| 149 |
|
| 150 |
def create_bar_chart(df, view_type):
|
| 151 |
"""Create interactive bar chart based on view type"""
|
| 152 |
+
|
| 153 |
if view_type == "Total Score":
|
| 154 |
fig = go.Figure(data=[
|
| 155 |
go.Bar(
|
|
|
|
| 226 |
|
| 227 |
return fig
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
# Header
|
| 230 |
st.markdown("""
|
| 231 |
<div class="header-container">
|
|
|
|
| 252 |
|
| 253 |
# Add user's model if it exists in session state
|
| 254 |
if 'user_results' in st.session_state:
|
| 255 |
+
user_results = pd.DataFrame([st.session_state.user_results])
|
| 256 |
+
df = pd.concat([user_results, df], ignore_index=True)
|
| 257 |
|
| 258 |
# View selector
|
| 259 |
view_type = st.selectbox(
|