Update app.py
app.py CHANGED
@@ -12,6 +12,14 @@ import sys
 import io
 from contextlib import redirect_stdout, redirect_stderr
 
+# Initialize session state variables
+if 'main_output' not in st.session_state:
+    st.session_state.main_output = []
+if 'debug_output' not in st.session_state:
+    st.session_state.debug_output = []
+if 'progress' not in st.session_state:
+    st.session_state.progress = 0
+
 # FILES
 iteration_output_file = "llm_benchmark_iteration_results.csv" # File to store iteration results, defined as global
 results_file = "llm_benchmark_results.csv" # all data
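Note: this hunk moves the session-state setup ahead of every function that touches it. Streamlit re-executes app.py from the top on each interaction, so keys read by custom_print and log_debug must exist before the first widget event fires. A minimal self-contained sketch of the same idempotent pattern (the loop form is just a compact equivalent, not the app's actual code):

import streamlit as st

# Assign defaults only when a key is missing; reruns leave existing values intact.
defaults = {'main_output': [], 'debug_output': [], 'progress': 0}
for key, value in defaults.items():
    if key not in st.session_state:
        st.session_state[key] = value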
@@ -41,12 +49,6 @@ difficulty_probabilities = {
     "a very difficult": 0.6
 }
 
-# Create output displays for main log and debug log
-if 'main_output' not in st.session_state:
-    st.session_state.main_output = []
-if 'debug_output' not in st.session_state:
-    st.session_state.debug_output = []
-
 # Custom print function to capture output
 def custom_print(*args, **kwargs):
     # Convert args to string and join with spaces
@@ -57,11 +59,17 @@ def custom_print(*args, **kwargs):
 
     # Also print to standard output for console logging
     print(*args, **kwargs)
+
+    # Force an immediate update of the UI (when used inside a function)
+    st.session_state.update_counter = st.session_state.get('update_counter', 0) + 1
 
 # Custom function to capture warnings and errors
 def log_debug(message):
     st.session_state.debug_output.append(message)
     print(f"DEBUG: {message}", file=sys.stderr)
+
+    # Force an immediate update of the UI
+    st.session_state.update_counter = st.session_state.get('update_counter', 0) + 1
 
 def retry_api_request(max_retries=3, wait_time=10):
     """Decorator for retrying API requests with rate limit handling."""
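The body of retry_api_request lies outside this hunk. For context, a minimal sketch of a decorator with this signature and docstring; the broad except clause and fixed sleep here are assumptions, since the real code presumably catches specific Hugging Face API and rate-limit errors:

import time
import functools

def retry_api_request(max_retries=3, wait_time=10):
    """Decorator for retrying API requests with rate limit handling."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except Exception:  # the real app likely catches narrower error types
                    if attempt == max_retries - 1:
                        raise
                    time.sleep(wait_time)
        return wrapper
    return decorator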
@@ -724,8 +732,11 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
 
 
     for model_id in active_models:
-        answer = answers
-
+        answer = answers.get(model_id)
+        if not answer: # Add guard clause
+            log_debug(f"No answer found for model {model_id}. Skipping ranking.")
+            continue
+
         if answer == "Error answering": # Handle answer generation errors
             consecutive_failures[model_id] += 1
             if consecutive_failures[model_id] >= failure_threshold:
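Note: the removed line bound the entire answers dict to answer, so the "Error answering" comparison below could never match. dict.get returns None for a missing key, and the falsy check also skips empty strings. A standalone illustration with hypothetical data:

answers = {'model-a': '42', 'model-b': ''}

for model_id in ['model-a', 'model-b', 'model-c']:
    answer = answers.get(model_id)  # None when the key is missing
    if not answer:                  # catches both None and the empty string
        print(f"No answer found for model {model_id}. Skipping ranking.")
        continue
    print(f"{model_id}: {answer}")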
@@ -794,7 +805,7 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
         results["question_prompt"].append(question_prompt)
         results["question"].append(question)
         results["answer"].append(answer)
-        results["answer_generation_duration"].append(
+        results["answer_generation_duration"].append(answer_durations.get(model_id, 0))
         results["average_rank"].append(average_rank)
         results["ranks"].append([ranks[m] for m in active_models if m in ranks]) # Store raw ranks including Nones, ensure order
         results["question_rank_average"].append(question_avg_rank) # Store question rank average
@@ -816,7 +827,7 @@ def run_benchmark(hf_models, topics, difficulties, t, model_config, token=None):
     total_valid_rank = 0 # Keep track of the sum of valid (non-NaN) ranks
 
     for m_id in active_models:
-        if cumulative_avg_rank[m_id]:
+        if m_id in cumulative_avg_rank and not np.isnan(cumulative_avg_rank[m_id]):
             temp_weights[m_id] = cumulative_avg_rank[m_id]
             total_valid_rank += cumulative_avg_rank[m_id]
         else: # if cumulative is empty, keep original
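Note: the removed test was wrong in two ways: a model_id missing from cumulative_avg_rank raised KeyError, and NaN is truthy in Python, so a NaN rank was never filtered out. A standalone check with hypothetical data:

import numpy as np

cumulative_avg_rank = {'model-a': 1.5, 'model-b': float('nan')}

assert bool(float('nan')) is True  # NaN would pass the old truthiness test

for m_id in ['model-a', 'model-b', 'model-c']:
    if m_id in cumulative_avg_rank and not np.isnan(cumulative_avg_rank[m_id]):
        print(f"{m_id}: usable rank {cumulative_avg_rank[m_id]}")
    else:
        print(f"{m_id}: no valid rank, keeping original weight")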
@@ -884,10 +895,6 @@ def check_model_availability(models, token):
 # Streamlit UI
 st.title("LLM Benchmark")
 
-# Initialize session state variables for progress tracking
-if 'progress' not in st.session_state:
-    st.session_state.progress = 0
-
 # Setup sidebar for configuration
 st.sidebar.header("Configuration")
 
@@ -970,6 +977,7 @@ with tab1:
         # Clear previous outputs
         st.session_state.main_output = []
         st.session_state.debug_output = []
+        st.session_state.progress = 0
 
         if not hf_token:
             st.error("Please enter your Hugging Face API token")
@@ -1038,21 +1046,28 @@ with tab1:
 with tab2:
     # Display main output log
     st.subheader("Execution Log")
-    log_container = st.container()
 
     # Display logs
    log_text = "\n".join(st.session_state.main_output)
-
+    st.text_area("Progress Log", log_text, height=400)
 
     # Add a refresh button for the log
-    if st.button("Refresh Log"):
-
+    if st.button("Refresh Progress Log"):
+        pass # The rerun happens automatically at the end
 
 with tab3:
     # Display debug output
     st.subheader("Debug Log")
-    debug_container = st.container()
 
     # Display debug logs
     debug_text = "\n".join(st.session_state.debug_output)
-
+    st.text_area("Debug Information", debug_text, height=400)
+
+    # Add a refresh button for the debug log
+    if st.button("Refresh Debug Log"):
+        pass # The rerun happens automatically at the end
+
+    # Auto-refresh mechanism
+    if st.session_state.get('update_counter', 0) > 0:
+        time.sleep(0.1) # Brief pause to allow UI to update
+        st.experimental_rerun()
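Two caveats on the auto-refresh block above: st.experimental_rerun has since been deprecated in favor of st.rerun in newer Streamlit releases, and because update_counter is never reset, each rerun re-executes the whole script, sees a positive counter, and schedules yet another rerun. A compatibility sketch; force_rerun is a hypothetical helper, not part of the app:

import streamlit as st

def force_rerun():
    # st.rerun superseded st.experimental_rerun; fall back for older installs.
    if hasattr(st, 'rerun'):
        st.rerun()
    else:
        st.experimental_rerun()

# Resetting the counter before rerunning avoids an endless rerun loop.
if st.session_state.get('update_counter', 0) > 0:
    st.session_state.update_counter = 0
    force_rerun()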