# rgb / src/streamlit_app.py
# Author: aaron-burr - "Upload streamlit visualization app" (commit 78ef790, verified)
import pandas as pd
import pygwalker as pyg
import streamlit as st
from pygwalker.api.streamlit import StreamlitRenderer
import random
st.set_page_config(layout="wide")


def no_op(*args, **kwargs):
    """Accept any arguments and do nothing.

    Installed below in place of Streamlit's
    ``maybe_show_deprecated_user_warning`` hook so that hook has no effect.
    """


# Replace Streamlit's deprecated-user warning hook with a no-op. The
# ``st.user_info`` module / hook may not exist in every Streamlit release, so
# patch it only when present rather than crashing at startup with an
# AttributeError.
_user_info_module = getattr(st, "user_info", None)
if hasattr(_user_info_module, "maybe_show_deprecated_user_warning"):
    _user_info_module.maybe_show_deprecated_user_warning = no_op
st.header("Leaderboard")
# The chosen name is interpolated into every results path below,
# e.g. ./src/mistral_<benchmark>.csv.
benchmark = st.selectbox(
    "Select the type of benchmark you want",
    options=("rabakbench", "toxicchat", "openaimod"),
)
def get_statistics(df):
    """Compute F1, precision, recall and mean latency for one model's results.

    Each confusion-matrix count is the number of *distinct* values in the
    ``prompt`` column among rows whose ``classification`` equals
    "True Positive", "False Positive" or "False Negative" respectively.
    Rows with any other label (e.g. "True Negative") only affect the
    latency mean.

    Args:
        df: Results table with ``classification``, ``prompt`` and ``time``
            columns.

    Returns:
        Tuple ``(f1, precision, recall, avg_time)``. A ratio whose
        denominator is zero is returned as ``0``; ``avg_time`` is the mean of
        the ``time`` column over all rows.
    """
    def distinct_prompts(label):
        # Unique prompts, so repeated rows for a prompt are counted once.
        return df.loc[df["classification"] == label, "prompt"].nunique()

    true_pos = distinct_prompts("True Positive")
    false_pos = distinct_prompts("False Positive")
    false_neg = distinct_prompts("False Negative")

    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg

    if predicted_pos > 0:
        precision = true_pos / predicted_pos
    else:
        precision = 0

    if actual_pos > 0:
        recall = true_pos / actual_pos
    else:
        recall = 0

    denom = precision + recall
    if denom > 0:
        f1 = 2 * (precision * recall) / denom
    else:
        f1 = 0

    return f1, precision, recall, df["time"].mean()
# Total number of rows in the selected benchmark (taken from the Llamaguard
# results file); used to decide whether to subsample.
df_len = pd.read_csv(f'./src/llamaguard3_8b_{benchmark}.csv').shape[0]
if df_len > 500:
    # Draw 500 *distinct* row positions (sampling without replacement) so the
    # subset really contains 500 different rows. A fixed seed keeps the
    # subset reproducible across reruns, and a private Random instance avoids
    # reseeding the module-global RNG.
    random_indices = random.Random(42).sample(range(df_len), 500)
else:
    random_indices = list(range(df_len))
@st.cache_data
def _read_results(path):
    """Return the results CSV at *path*, cached across Streamlit reruns.

    ``st.cache_data`` returns a fresh copy on each call, so the caller may
    slice the frame freely without affecting the cache.
    """
    return pd.read_csv(path)


# The same row positions (random_indices) are taken from every model's file.
# NOTE(review): this presumes the three CSVs list the benchmark prompts in the
# same order and have at least df_len rows - confirm.

## get the statistics for the selected benchmark for Llamaguard 3 8b
df_ll = _read_results(f'./src/llamaguard3_8b_{benchmark}.csv').iloc[random_indices]
ll_f1, ll_precision, ll_recall, ll_time = get_statistics(df_ll)
## get the statistics for the selected benchmark for Mistral Moderation
df_mm = _read_results(f'./src/mistral_{benchmark}.csv').iloc[random_indices]
mm_f1, mm_precision, mm_recall, mm_time = get_statistics(df_mm)
## get the statistics for the selected benchmark for Qwen3Guard
df_qw = _read_results(f'./src/qwen3guard_{benchmark}.csv').iloc[random_indices]
qw_f1, qw_precision, qw_recall, qw_time = get_statistics(df_qw)
## display statistics table
# One row per model: (display name, F1, precision, recall, mean seconds per
# request), every metric rounded to two decimals.
df = pd.DataFrame(
    [
        {
            "model": model_name,
            "F1": round(f1_score, 2),
            "Precision": round(prec, 2),
            "Recall": round(rec, 2),
            "Avg Time/req (s)": round(secs, 2),
        }
        for model_name, f1_score, prec, rec, secs in (
            ("Mistral Moderation", mm_f1, mm_precision, mm_recall, mm_time),
            ("Qwen3Guard", qw_f1, qw_precision, qw_recall, qw_time),
            ("llamaguard 3 8b", ll_f1, ll_precision, ll_recall, ll_time),
        )
    ]
)
st.dataframe(df, hide_index=True, use_container_width=True)
st.header("Analysis")
# Selector label -> sampled results frame of the corresponding model.
df_dictionary = {
    "Qwen3Guard": df_qw,
    "Llama3_8b": df_ll,
    "Mistral": df_mm,
}
c1, c2 = st.columns([4, 3])
with c1:
    # Benchmarks with more than 500 rows were subsampled to 500 rows above.
    if df_len > 500:
        st.write(f"Analysis of 500 randomly selected samples from {benchmark} benchmark on")
    else:
        st.write(f"Analysis of full {benchmark} benchmark on")
with c2:
    # Label is hidden so the selector completes the sentence written in c1.
    # Options are taken from df_dictionary so the two cannot drift apart; the
    # selectbox always returns one of them, so no separate default is needed.
    guardrail = st.selectbox(
        "model-type",
        list(df_dictionary),
        label_visibility="collapsed",
    )
@st.cache_resource
def load_pygwalker(df_dictionary, guardrail):
    """Build a PyGWalker explorer over the chosen guardrail's results.

    The renderer is cached by ``st.cache_resource`` keyed on the hashed
    arguments, so it is rebuilt only when the data or the selection changes.
    The chart layout is loaded from ./src/pygwalker_spec_display.json.
    """
    chosen_df = df_dictionary[guardrail]
    renderer = StreamlitRenderer(
        chosen_df,
        spec='./src/pygwalker_spec_display.json',
        scrolling=True,
    )
    return renderer


pygapp = load_pygwalker(df_dictionary, guardrail)
pygapp.explorer()