Anas Awadalla committed
Commit · 79cb6e1
Parent(s): 2dbb46e

more analysis + baselines

- README.md +37 -13
- src/streamlit_app.py +285 -33
README.md
CHANGED
@@ -20,10 +20,11 @@ A Streamlit application for visualizing model performance on grounding benchmark
  20   - **Real-time Data**: Streams results directly from the HuggingFace leaderboard repository without local storage
  21   - **Interactive Visualizations**: Bar charts comparing model performance across different metrics
  22   - **Baseline Comparisons**: Shows baseline models (Qwen2-VL, UI-TARS) alongside evaluated models
  23 - - **
  24 -
  25 -
  26 -
  27   - **Model Details**: View training loss, checkpoint steps, and evaluation timestamps
  28   - **Sample Results**: Inspect the first 5 evaluation samples for each model
  29

@@ -49,15 +50,23 @@ The app will open in your browser at `http://localhost:8501`
  49
  50   2. **Filter Models**: Optionally filter to view a specific model or all models
  51
  52 - 3. **View Charts**:
  53 -
  54 -
  55 -
  56
  57   4. **Explore Details**:
  58   - Expand "Model Details" to see training metadata
  59   - Expand "Detailed UI Type Breakdown" for a comprehensive table
  60   - Expand "Sample Results" to see the first 5 evaluation samples

@@ -75,7 +84,7 @@ To minimize local storage requirements, the app:
  75
  76   ## Supported Datasets
  77
  78 - - **ScreenSpot-v2**: Web and desktop UI element grounding
  79   - **ScreenSpot-Pro**: Professional UI grounding benchmark
  80   - **ShowdownClicks**: Click prediction benchmark
  81   - And more as they are added to the leaderboard

@@ -83,10 +92,25 @@ To minimize local storage requirements, the app:
  83   ## Baseline Models
  84
  85   For ScreenSpot-v2, the following baselines are included:
  86 - - Qwen2-VL-7B
  87 - - UI-TARS-2B
  88 - - UI-TARS-7B
  89 - - UI-TARS-72B
  90
  91   ## Caching
  92
  20   - **Real-time Data**: Streams results directly from the HuggingFace leaderboard repository without local storage
  21   - **Interactive Visualizations**: Bar charts comparing model performance across different metrics
  22   - **Baseline Comparisons**: Shows baseline models (Qwen2-VL, UI-TARS) alongside evaluated models
  23 + - **Best Checkpoint Selection**: Automatically shows the best performing checkpoint for each model (marked with * if not the last checkpoint)
  24 + - **UI Type Breakdown**:
  25 +   - For ScreenSpot-v2: Comprehensive charts showing Overall, Desktop, Web, and individual UI type performance
  26 +   - For other datasets: Desktop vs Web and Text vs Icon performance
  27 + - **Checkpoint Progression Analysis**: Visualize how metrics evolve during training
  28   - **Model Details**: View training loss, checkpoint steps, and evaluation timestamps
  29   - **Sample Results**: Inspect the first 5 evaluation samples for each model
  30

  50
  51   2. **Filter Models**: Optionally filter to view a specific model or all models
  52
  53 + 3. **View Charts**:
  54 +   - For ScreenSpot-v2:
  55 +     - Overall performance (average of desktop and web)
  56 +     - Desktop and Web averages
  57 +     - Individual UI type metrics: Desktop (Text), Desktop (Icon), Web (Text), Web (Icon)
  58 +     - Text and Icon averages across environments
  59 +     - Baseline model comparisons shown in orange
  60 +   - Models marked with * indicate the best checkpoint is not the final one
  61
  62   4. **Explore Details**:
  63   - Expand "Model Details" to see training metadata
  64   - Expand "Detailed UI Type Breakdown" for a comprehensive table
  65   - Expand "Sample Results" to see the first 5 evaluation samples
  66 + - Expand "Checkpoint Progression Analysis" to:
  67 +   - View accuracy progression over training steps
  68 +   - See the relationship between training loss and accuracy
  69 +   - Compare performance across checkpoints
  70
  71   ## Data Source
  72

  84
  85   ## Supported Datasets
  86
  87 + - **ScreenSpot-v2**: Web and desktop UI element grounding (with special handling for desktop/web averaging)
  88   - **ScreenSpot-Pro**: Professional UI grounding benchmark
  89   - **ShowdownClicks**: Click prediction benchmark
  90   - And more as they are added to the leaderboard

  92   ## Baseline Models
  93
  94   For ScreenSpot-v2, the following baselines are included:
  95 + - Qwen2-VL-7B: 37.96%
  96 + - UI-TARS-2B: 82.8%
  97 + - UI-TARS-7B: 92.2%
  98 + - UI-TARS-72B: 88.3%
  99 +
 100 + For ScreenSpot-Pro, the following baselines are included:
 101 + - Qwen2.5-VL-3B-Instruct: 16.1%
 102 + - Qwen2.5-VL-7B-Instruct: 26.8%
 103 + - Qwen2.5-VL-72B-Instruct: 53.3%
 104 + - UI-TARS-2B: 27.7%
 105 + - UI-TARS-7B: 35.7%
 106 + - UI-TARS-72B: 38.1%
 107 +
 108 + ## Checkpoint Handling
 109 +
 110 + - The app automatically identifies the best performing checkpoint for each model
 111 + - If multiple checkpoints exist, only the best one is shown in the main charts
 112 + - An asterisk (*) indicates when the best checkpoint is not the last one
 113 + - Use the "Checkpoint Progression Analysis" to explore all checkpoints
 114
 115   ## Caching
 116
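The best-checkpoint rule described under "Checkpoint Handling" can be sketched with pandas roughly as below. This is a minimal illustration, not the app's exact code: the column names (`base_model`, `model`, `checkpoint_steps`, `overall_accuracy`) follow the README's wording and are assumptions about the schema.

```python
import pandas as pd

def pick_best_checkpoints(df: pd.DataFrame) -> pd.DataFrame:
    """Keep one row per base model: the highest-accuracy checkpoint,
    starred when it is not the last checkpoint."""
    best_rows = []
    for _, group in df.groupby("base_model"):
        best = group.loc[group["overall_accuracy"].idxmax()].copy()
        if best["checkpoint_steps"] != group["checkpoint_steps"].max():
            best["model"] = best["model"] + "*"  # best is not the last checkpoint
        best_rows.append(best)
    return pd.DataFrame(best_rows)

runs = pd.DataFrame({
    "base_model": ["m1", "m1", "m1"],
    "model": ["m1/checkpoint-100", "m1/checkpoint-200", "m1/checkpoint-300"],
    "checkpoint_steps": [100, 200, 300],
    "overall_accuracy": [80.0, 91.5, 89.0],
})
print(pick_best_checkpoints(runs)["model"].tolist())  # ['m1/checkpoint-200*']
```

Here checkpoint-200 wins on accuracy but is not the final checkpoint, so it is shown with the asterisk.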
src/streamlit_app.py
CHANGED
@@ -53,6 +53,26 @@ BASELINES = {
  53           "web_icon": 86.3,
  54           "overall": 88.3
  55       }
  56   }
  57   }
  58

@@ -99,18 +119,36 @@ def fetch_leaderboard_data():
  99   # Get model name from metadata or path
 100   model_checkpoint = metadata.get("model_checkpoint", "")
 101   model_name = model_checkpoint.split('/')[-1]
 102
 103   # Handle checkpoint names
 104   if not model_name and len(path_parts) > 2:
 105       # Check if it's a checkpoint subdirectory structure
 106       if len(path_parts) > 3 and path_parts[2] != path_parts[3]:
 107           # Format: grounding/dataset/base_model/checkpoint.json
 108 -
 109           checkpoint_file = path_parts[3].replace(".json", "")
 110 -         model_name = f"{
 111       else:
 112           # Regular format: grounding/dataset/results_modelname.json
 113           model_name = path_parts[2].replace("results_", "").replace(".json", "")
 114
 115   # Extract UI type results if available
 116   ui_type_results = detailed_results.get("by_ui_type", {})

@@ -120,6 +158,8 @@ def fetch_leaderboard_data():
 120   result_entry = {
 121       "dataset": dataset_name,
 122       "model": model_name,
 123       "model_path": model_checkpoint,
 124       "overall_accuracy": metrics.get("accuracy", 0) * 100,  # Convert to percentage
 125       "total_samples": metrics.get("total", 0),

@@ -145,7 +185,49 @@ def fetch_leaderboard_data():
 145   progress_bar.empty()
 146   status_text.empty()
 147
 148 -
 149
 150   except Exception as e:
 151       st.error(f"Error fetching leaderboard data: {str(e)}")

@@ -164,17 +246,23 @@ def parse_ui_type_metrics(df: pd.DataFrame, dataset_filter: str) -> pd.DataFrame
 164
 165   # For ScreenSpot datasets, we have desktop/web and text/icon
 166   if 'screenspot' in dataset_filter.lower():
 167 -     # Calculate
 168       desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
 169       desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
 170       web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
 171       web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
 172
 173       # Calculate averages
 174 -     desktop_avg = (desktop_text + desktop_icon) / 2 if desktop_text or desktop_icon else 0
 175 -     web_avg = (web_text + web_icon) / 2 if web_text or web_icon else 0
 176 -     text_avg = (desktop_text + web_text) / 2 if desktop_text or web_text else 0
 177 -     icon_avg = (desktop_icon + web_icon) / 2 if desktop_icon or web_icon else 0
 178
 179   metrics_list.append({
 180       'model': model,

@@ -186,7 +274,9 @@ def parse_ui_type_metrics(df: pd.DataFrame, dataset_filter: str) -> pd.DataFrame
 186   'web_avg': web_avg,
 187   'text_avg': text_avg,
 188   'icon_avg': icon_avg,
 189 - 'overall':
 190   })
 191
 192   return pd.DataFrame(metrics_list)

@@ -303,35 +393,197 @@ def main():
 303   if not ui_metrics_df.empty and 'screenspot' in selected_dataset.lower():
 304       st.subheader("Performance by UI Type")
 305
 306 -     #
 329 -     st.altair_chart(chart, use_container_width=True)
 335
 336   # Detailed breakdown
 337   with st.expander("Detailed UI Type Breakdown"):
  53           "web_icon": 86.3,
  54           "overall": 88.3
  55       }
  56 + },
  57 + "screenspot-pro": {
  58 +     "Qwen2.5-VL-3B-Instruct": {
  59 +         "overall": 16.1
  60 +     },
  61 +     "Qwen2.5-VL-7B-Instruct": {
  62 +         "overall": 26.8
  63 +     },
  64 +     "Qwen2.5-VL-72B-Instruct": {
  65 +         "overall": 53.3
  66 +     },
  67 +     "UI-TARS-2B": {
  68 +         "overall": 27.7
  69 +     },
  70 +     "UI-TARS-7B": {
  71 +         "overall": 35.7
  72 +     },
  73 +     "UI-TARS-72B": {
  74 +         "overall": 38.1
  75 +     }
  76   }
  77   }
  78
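The nested `BASELINES` table above is keyed dataset → model → metric. A safe lookup chains `dict.get` so missing datasets or models fall through to `None` instead of raising; a tiny sketch (only two entries copied from the diff, the `baselines` name and helper are illustrative):

```python
# Abbreviated copy of the nested baseline table from the diff above.
baselines = {
    "screenspot-pro": {
        "Qwen2.5-VL-3B-Instruct": {"overall": 16.1},
        "UI-TARS-72B": {"overall": 38.1},
    },
}

def baseline_overall(dataset: str, model: str):
    """Return the baseline overall score, or None when absent."""
    return baselines.get(dataset, {}).get(model, {}).get("overall")

print(baseline_overall("screenspot-pro", "UI-TARS-72B"))  # 38.1
print(baseline_overall("screenspot-v2", "UI-TARS-72B"))   # None
```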
 119   # Get model name from metadata or path
 120   model_checkpoint = metadata.get("model_checkpoint", "")
 121   model_name = model_checkpoint.split('/')[-1]
 122 + base_model_name = None
 123 + is_checkpoint = False
 124
 125   # Handle checkpoint names
 126   if not model_name and len(path_parts) > 2:
 127       # Check if it's a checkpoint subdirectory structure
 128       if len(path_parts) > 3 and path_parts[2] != path_parts[3]:
 129           # Format: grounding/dataset/base_model/checkpoint.json
 130 +         base_model_name = path_parts[2]
 131           checkpoint_file = path_parts[3].replace(".json", "")
 132 +         model_name = f"{base_model_name}/{checkpoint_file}"
 133 +         is_checkpoint = True
 134       else:
 135           # Regular format: grounding/dataset/results_modelname.json
 136           model_name = path_parts[2].replace("results_", "").replace(".json", "")
 137 +         base_model_name = model_name
 138 +
 139 + # Check if model name indicates a checkpoint
 140 + if 'checkpoint-' in model_name:
 141 +     is_checkpoint = True
 142 +     if not base_model_name:
 143 +         # Extract base model name from full path
 144 +         if '/' in model_name:
 145 +             parts = model_name.split('/')
 146 +             base_model_name = parts[0]
 147 +         else:
 148 +             # Try to get from model_checkpoint path
 149 +             checkpoint_parts = model_checkpoint.split('/')
 150 +             if len(checkpoint_parts) > 1:
 151 +                 base_model_name = checkpoint_parts[-2]
 152
 153   # Extract UI type results if available
 154   ui_type_results = detailed_results.get("by_ui_type", {})
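The two path layouts handled above can be exercised in isolation. This sketch condenses the same branching into one helper (the function name and example paths are invented for illustration; it assumes the comment's two formats, `grounding/dataset/base_model/checkpoint.json` and `grounding/dataset/results_modelname.json`):

```python
def parse_result_path(path: str):
    """Return (base_model, display_name, is_checkpoint) for a result file path."""
    parts = path.split("/")
    if len(parts) > 3 and parts[2] != parts[3]:
        # grounding/dataset/base_model/checkpoint.json
        base = parts[2]
        name = f"{base}/{parts[3].removesuffix('.json')}"
        return base, name, True
    # grounding/dataset/results_modelname.json
    name = parts[2].removesuffix(".json").removeprefix("results_")
    return name, name, "checkpoint-" in name

print(parse_result_path("grounding/screenspot-v2/my-model/checkpoint-500.json"))
# ('my-model', 'my-model/checkpoint-500', True)
print(parse_result_path("grounding/screenspot-v2/results_foo.json"))
# ('foo', 'foo', False)
```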
 158   result_entry = {
 159       "dataset": dataset_name,
 160       "model": model_name,
 161 +     "base_model": base_model_name or model_name,
 162 +     "is_checkpoint": is_checkpoint,
 163       "model_path": model_checkpoint,
 164       "overall_accuracy": metrics.get("accuracy", 0) * 100,  # Convert to percentage
 165       "total_samples": metrics.get("total", 0),
 185   progress_bar.empty()
 186   status_text.empty()
 187
 188 + # Create DataFrame
 189 + df = pd.DataFrame(results)
 190 +
 191 + # Process checkpoints: for each base model, find the best checkpoint
 192 + if not df.empty:
 193 +     # Group by dataset and base_model
 194 +     grouped = df.groupby(['dataset', 'base_model'])
 195 +
 196 +     # For each group, find the best checkpoint
 197 +     best_models = []
 198 +     for (dataset, base_model), group in grouped:
 199 +         if len(group) > 1:
 200 +             # Multiple entries for this model (likely checkpoints)
 201 +             best_idx = group['overall_accuracy'].idxmax()
 202 +             best_row = group.loc[best_idx].copy()
 203 +
 204 +             # Check if the best is the last checkpoint
 205 +             checkpoint_steps = group[group['checkpoint_steps'].notna()]['checkpoint_steps'].sort_values()
 206 +             if len(checkpoint_steps) > 0:
 207 +                 last_checkpoint_steps = checkpoint_steps.iloc[-1]
 208 +                 best_checkpoint_steps = best_row['checkpoint_steps']
 209 +                 if pd.notna(best_checkpoint_steps) and best_checkpoint_steps != last_checkpoint_steps:
 210 +                     # Best checkpoint is not the last one, add asterisk
 211 +                     best_row['model'] = best_row['model'] + '*'
 212 +                     best_row['is_best_not_last'] = True
 213 +                 else:
 214 +                     best_row['is_best_not_last'] = False
 215 +
 216 +             # Store all checkpoints for this model
 217 +             best_row['all_checkpoints'] = group.to_dict('records')
 218 +             best_models.append(best_row)
 219 +         else:
 220 +             # Single entry for this model
 221 +             row = group.iloc[0].copy()
 222 +             row['is_best_not_last'] = False
 223 +             row['all_checkpoints'] = [row.to_dict()]
 224 +             best_models.append(row)
 225 +
 226 +     # Create new dataframe with best models
 227 +     df_best = pd.DataFrame(best_models)
 228 +     return df_best
 229 +
 230 + return df
 231
 232   except Exception as e:
 233       st.error(f"Error fetching leaderboard data: {str(e)}")
 246
 247   # For ScreenSpot datasets, we have desktop/web and text/icon
 248   if 'screenspot' in dataset_filter.lower():
 249 +     # Calculate individual metrics
 250       desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
 251       desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
 252       web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
 253       web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
 254
 255       # Calculate averages
 256 +     desktop_avg = (desktop_text + desktop_icon) / 2 if (desktop_text > 0 or desktop_icon > 0) else 0
 257 +     web_avg = (web_text + web_icon) / 2 if (web_text > 0 or web_icon > 0) else 0
 258 +     text_avg = (desktop_text + web_text) / 2 if (desktop_text > 0 or web_text > 0) else 0
 259 +     icon_avg = (desktop_icon + web_icon) / 2 if (desktop_icon > 0 or web_icon > 0) else 0
 260 +
 261 +     # For screenspot-v2, calculate the overall as average of desktop and web
 262 +     if dataset_filter == 'screenspot-v2':
 263 +         overall = (desktop_avg + web_avg) / 2 if (desktop_avg > 0 or web_avg > 0) else 0
 264 +     else:
 265 +         overall = row['overall_accuracy']
 266
 267   metrics_list.append({
 268       'model': model,
 274       'web_avg': web_avg,
 275       'text_avg': text_avg,
 276       'icon_avg': icon_avg,
 277 +     'overall': overall,
 278 +     'is_best_not_last': row.get('is_best_not_last', False),
 279 +     'all_checkpoints': row.get('all_checkpoints', [])
 280   })
 281
 282   return pd.DataFrame(metrics_list)
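A worked example of the averaging in this hunk: each bucket's accuracy is `correct / total * 100`, desktop/web averages are the mean of their text/icon buckets, and for screenspot-v2 the overall score is the mean of the desktop and web averages. The counts below are invented for illustration:

```python
def bucket_acc(bucket):
    """Accuracy of one UI-type bucket as a percentage (0 when empty)."""
    return bucket.get("correct", 0) / max(bucket.get("total", 1), 1) * 100

ui = {
    "desktop_text": {"correct": 9, "total": 10},
    "desktop_icon": {"correct": 7, "total": 10},
    "web_text": {"correct": 8, "total": 10},
    "web_icon": {"correct": 6, "total": 10},
}
desktop_avg = (bucket_acc(ui["desktop_text"]) + bucket_acc(ui["desktop_icon"])) / 2
web_avg = (bucket_acc(ui["web_text"]) + bucket_acc(ui["web_icon"])) / 2
overall = (desktop_avg + web_avg) / 2  # screenspot-v2 convention
print(desktop_avg, web_avg, overall)  # 80.0 70.0 75.0
```

Note this overall is an unweighted mean of the two environments, so it matches `overall_accuracy` only when desktop and web have equal sample counts.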
 393   if not ui_metrics_df.empty and 'screenspot' in selected_dataset.lower():
 394       st.subheader("Performance by UI Type")
 395
 396 +     # Add note about asterisks
 397 +     if any(ui_metrics_df['is_best_not_last']):
 398 +         st.info("* indicates the best checkpoint is not the last checkpoint")
 399
 400 +     # Create charts in a grid
 401 +     if selected_dataset == 'screenspot-v2':
 402 +         # First row: Overall, Desktop, Web averages
 403 +         col1, col2, col3 = st.columns(3)
 404 +
 405 +         with col1:
 406 +             chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average (Desktop + Web) / 2')
 407 +             if chart:
 408 +                 st.altair_chart(chart, use_container_width=True)
 409 +
 410 +         with col2:
 411 +             chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
 412 +             if chart:
 413 +                 st.altair_chart(chart, use_container_width=True)
 414 +
 415 +         with col3:
 416 +             chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
 417 +             if chart:
 418 +                 st.altair_chart(chart, use_container_width=True)
 419 +
 420 +         # Second row: Individual UI type metrics
 421 +         col1, col2, col3, col4 = st.columns(4)
 422 +
 423 +         with col1:
 424 +             chart = create_bar_chart(ui_metrics_df, 'desktop_text', 'Desktop (Text)')
 425 +             if chart:
 426 +                 st.altair_chart(chart, use_container_width=True)
 427
 428 +         with col2:
 429 +             chart = create_bar_chart(ui_metrics_df, 'desktop_icon', 'Desktop (Icon)')
 430 +             if chart:
 431 +                 st.altair_chart(chart, use_container_width=True)
 432
 433 +         with col3:
 434 +             chart = create_bar_chart(ui_metrics_df, 'web_text', 'Web (Text)')
 435 +             if chart:
 436 +                 st.altair_chart(chart, use_container_width=True)
 437 +
 438 +         with col4:
 439 +             chart = create_bar_chart(ui_metrics_df, 'web_icon', 'Web (Icon)')
 440 +             if chart:
 441 +                 st.altair_chart(chart, use_container_width=True)
 442 +
 443 +         # Third row: Text vs Icon averages
 444 +         col1, col2 = st.columns(2)
 445 +
 446 +         with col1:
 447 +             chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (Desktop + Web)')
 448 +             if chart:
 449 +                 st.altair_chart(chart, use_container_width=True)
 450 +
 451 +         with col2:
 452 +             chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (Desktop + Web)')
 453 +             if chart:
 454 +                 st.altair_chart(chart, use_container_width=True)
 455 +     else:
 456 +         # For other screenspot datasets, show the standard layout
 457 +         col1, col2 = st.columns(2)
 458 +
 459 +         with col1:
 460 +             # Overall Average
 461 +             chart = create_bar_chart(ui_metrics_df, 'overall', 'Overall Average')
 462 +             if chart:
 463 +                 st.altair_chart(chart, use_container_width=True)
 464 +
 465 +             # Desktop Average
 466 +             chart = create_bar_chart(ui_metrics_df, 'desktop_avg', 'Desktop Average')
 467 +             if chart:
 468 +                 st.altair_chart(chart, use_container_width=True)
 469 +
 470 +             # Text Average
 471 +             chart = create_bar_chart(ui_metrics_df, 'text_avg', 'Text Average (UI-Type)')
 472 +             if chart:
 473 +                 st.altair_chart(chart, use_container_width=True)
 474 +
 475 +         with col2:
 476 +             # Web Average
 477 +             chart = create_bar_chart(ui_metrics_df, 'web_avg', 'Web Average')
 478 +             if chart:
 479 +                 st.altair_chart(chart, use_container_width=True)
 480 +
 481 +             # Icon Average
 482 +             chart = create_bar_chart(ui_metrics_df, 'icon_avg', 'Icon Average (UI-Type)')
 483 +             if chart:
 484 +                 st.altair_chart(chart, use_container_width=True)
 485
 486 +     # Checkpoint progression visualization
 487 +     with st.expander("Checkpoint Progression Analysis"):
 488 +         # Select a model with checkpoints
 489 +         models_with_checkpoints = ui_metrics_df[ui_metrics_df['all_checkpoints'].apply(lambda x: len(x) > 1)]
 490
 491 +         if not models_with_checkpoints.empty:
 492 +             selected_checkpoint_model = st.selectbox(
 493 +                 "Select a model to view checkpoint progression:",
 494 +                 models_with_checkpoints['model'].str.replace('*', '').unique()
 495 +             )
 496 +
 497 +             # Get checkpoint data for selected model
 498 +             model_row = models_with_checkpoints[models_with_checkpoints['model'].str.replace('*', '') == selected_checkpoint_model].iloc[0]
 499 +             checkpoint_data = model_row['all_checkpoints']
 500 +
 501 +             # Create DataFrame from checkpoint data
 502 +             checkpoint_df = pd.DataFrame(checkpoint_data)
 503 +
 504 +             # Prepare data for visualization
 505 +             checkpoint_metrics = []
 506 +             for _, cp in checkpoint_df.iterrows():
 507 +                 ui_results = cp['ui_type_results']
 508 +
 509 +                 # Calculate metrics
 510 +                 desktop_text = ui_results.get('desktop_text', {}).get('correct', 0) / max(ui_results.get('desktop_text', {}).get('total', 1), 1) * 100
 511 +                 desktop_icon = ui_results.get('desktop_icon', {}).get('correct', 0) / max(ui_results.get('desktop_icon', {}).get('total', 1), 1) * 100
 512 +                 web_text = ui_results.get('web_text', {}).get('correct', 0) / max(ui_results.get('web_text', {}).get('total', 1), 1) * 100
 513 +                 web_icon = ui_results.get('web_icon', {}).get('correct', 0) / max(ui_results.get('web_icon', {}).get('total', 1), 1) * 100
 514 +
 515 +                 desktop_avg = (desktop_text + desktop_icon) / 2
 516 +                 web_avg = (web_text + web_icon) / 2
 517 +                 overall = (desktop_avg + web_avg) / 2 if selected_dataset == 'screenspot-v2' else cp['overall_accuracy']
 518 +
 519 +                 checkpoint_metrics.append({
 520 +                     'steps': cp['checkpoint_steps'] or 0,
 521 +                     'overall': overall,
 522 +                     'desktop': desktop_avg,
 523 +                     'web': web_avg,
 524 +                     'loss': cp['training_loss'],
 525 +                     'neg_log_loss': -np.log(cp['training_loss']) if cp['training_loss'] and cp['training_loss'] > 0 else None
 526 +                 })
 527 +
 528 +             metrics_df = pd.DataFrame(checkpoint_metrics).sort_values('steps')
 529 +
 530 +             # Plot metrics over training steps
 531 +             col1, col2 = st.columns(2)
 532 +
 533 +             with col1:
 534 +                 st.write("**Accuracy over Training Steps**")
 535 +
 536 +                 # Melt data for multi-line chart
 537 +                 melted = metrics_df[['steps', 'overall', 'desktop', 'web']].melt(
 538 +                     id_vars=['steps'],
 539 +                     var_name='Metric',
 540 +                     value_name='Accuracy'
 541 +                 )
 542 +
 543 +                 chart = alt.Chart(melted).mark_line(point=True).encode(
 544 +                     x=alt.X('steps:Q', title='Training Steps'),
 545 +                     y=alt.Y('Accuracy:Q', scale=alt.Scale(domain=[0, 100]), title='Accuracy (%)'),
 546 +                     color=alt.Color('Metric:N', scale=alt.Scale(
 547 +                         domain=['overall', 'desktop', 'web'],
 548 +                         range=['#4ECDC4', '#45B7D1', '#96CEB4']
 549 +                     )),
 550 +                     tooltip=['steps', 'Metric', 'Accuracy']
 551 +                 ).properties(
 552 +                     width=400,
 553 +                     height=300,
 554 +                     title='Accuracy Progression During Training'
 555 +                 )
 556 +                 st.altair_chart(chart, use_container_width=True)
 557 +
 558 +             with col2:
 559 +                 st.write("**Accuracy vs. Training Loss**")
 560 +
 561 +                 if metrics_df['neg_log_loss'].notna().any():
 562 +                     scatter_data = metrics_df[metrics_df['neg_log_loss'].notna()]
 563 +
 564 +                     chart = alt.Chart(scatter_data).mark_circle(size=100).encode(
 565 +                         x=alt.X('neg_log_loss:Q', title='-log(Training Loss)'),
 566 +                         y=alt.Y('overall:Q', scale=alt.Scale(domain=[0, 100]), title='Overall Accuracy (%)'),
 567 +                         color=alt.Color('steps:Q', scale=alt.Scale(scheme='viridis'), title='Training Steps'),
 568 +                         tooltip=['steps', 'loss', 'overall']
 569 +                     ).properties(
 570 +                         width=400,
 571 +                         height=300,
 572 +                         title='Accuracy vs. -log(Training Loss)'
 573 +                     )
 574 +                     st.altair_chart(chart, use_container_width=True)
 575 +                 else:
 576 +                     st.info("No training loss data available for this model")
 577 +
 578 +             # Show checkpoint details table
 579 +             st.write("**Checkpoint Details**")
 580 +             display_metrics = metrics_df[['steps', 'overall', 'desktop', 'web', 'loss']].copy()
 581 +             display_metrics.columns = ['Steps', 'Overall %', 'Desktop %', 'Web %', 'Training Loss']
 582 +             display_metrics[['Overall %', 'Desktop %', 'Web %']] = display_metrics[['Overall %', 'Desktop %', 'Web %']].round(2)
 583 +             display_metrics['Training Loss'] = display_metrics['Training Loss'].apply(lambda x: f"{x:.4f}" if pd.notna(x) else "N/A")
 584 +             st.dataframe(display_metrics, use_container_width=True)
 585 +         else:
 586 +             st.info("No models with multiple checkpoints available for progression analysis")
 587
 588   # Detailed breakdown
 589   with st.expander("Detailed UI Type Breakdown"):
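The data prep feeding the progression charts in the hunk above can be checked without Streamlit or Altair: the `-log(loss)` transform guards against missing or non-positive loss, and `melt()` reshapes the per-checkpoint metrics into long form for a multi-line chart. A small sketch with invented numbers (the `neg_log_loss` helper mirrors the inline expression in the diff but uses `math.log` instead of `np.log`):

```python
import math
import pandas as pd

def neg_log_loss(loss):
    """-log(loss), or None when loss is missing or non-positive."""
    return -math.log(loss) if loss and loss > 0 else None

metrics_df = pd.DataFrame({
    "steps": [100, 200],
    "overall": [75.0, 82.0],
    "desktop": [80.0, 85.0],
    "web": [70.0, 79.0],
})
# Long form: one row per (steps, metric) pair, ready for a color-encoded line chart
melted = metrics_df.melt(id_vars=["steps"], var_name="Metric", value_name="Accuracy")
print(len(melted))                  # 6 rows: 2 steps x 3 metrics
print(round(neg_log_loss(0.5), 3))  # 0.693
print(neg_log_loss(None))           # None
```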