William Arnold committed on
Commit ·
48172f6
1
Parent(s): bafdc7f
Perf update
Browse files
- .streamlit/config.toml +2 -0
- app.py +4 -3
- requirements.txt +1 -1
- src/rbeval/dash.py +53 -21
- src/rbeval/plot/data.py +1 -0
- src/rbeval/plot/score_cdf.py +29 -24
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[theme]
|
| 2 |
+
base="light"
|
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import sys
|
| 2 |
-
sys.path.append('./src/')
|
| 3 |
-
import rbeval.dash
|
| 4 |
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
| 1 |
import sys
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
sys.path.append("./src/")
|
| 4 |
+
from rbeval.dash import main
|
| 5 |
|
| 6 |
+
|
| 7 |
+
main()
|
requirements.txt
CHANGED
|
@@ -4,4 +4,4 @@ huggingface-hub>=0.24.2
|
|
| 4 |
tqdm>=4.66.4
|
| 5 |
numpy>=1.26.4
|
| 6 |
dacite>=1.8.1
|
| 7 |
-
seaborn>=0.13.1
|
|
|
|
| 4 |
tqdm>=4.66.4
|
| 5 |
numpy>=1.26.4
|
| 6 |
dacite>=1.8.1
|
| 7 |
+
seaborn>=0.13.1
|
src/rbeval/dash.py
CHANGED
|
@@ -7,6 +7,8 @@ from dacite import from_dict
|
|
| 7 |
|
| 8 |
from rbeval.plot.data import EvalGroup, get_samples
|
| 9 |
from rbeval.plot.score_cdf import (
|
|
|
|
|
|
|
| 10 |
plot_with_data,
|
| 11 |
get_plot_data,
|
| 12 |
plot_cfgs,
|
|
@@ -24,7 +26,9 @@ def cached_samples(dir: Path, name_filter: Optional[str]) -> List[EvalGroup]:
|
|
| 24 |
|
| 25 |
|
| 26 |
@st.cache_data
|
| 27 |
-
def cached_score_cdf(
|
|
|
|
|
|
|
| 28 |
samples = cached_samples(dir, name_filter)
|
| 29 |
cfgs = plot_cfgs()
|
| 30 |
data = [get_plot_data(cfg, samples) for cfg in cfgs]
|
|
@@ -43,17 +47,41 @@ def cache_compare(
|
|
| 43 |
return grouped_dict, base_name, comp_name
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def main():
|
| 47 |
parser = argparse.ArgumentParser(description="rbeval dashboard")
|
| 48 |
parser.add_argument("--evals", type=str, default="./lmo-fake", required=False)
|
| 49 |
args, _rest = parser.parse_known_args()
|
| 50 |
eval_dir = Path(args.evals)
|
| 51 |
# Show all the models
|
|
|
|
| 52 |
st.set_page_config(layout="wide")
|
| 53 |
score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
for fig in figs:
|
| 58 |
st.altair_chart(fig.chart) # type: ignore
|
| 59 |
|
|
@@ -64,23 +92,27 @@ def main():
|
|
| 64 |
for m in group.model_evals
|
| 65 |
]
|
| 66 |
)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
if __name__ == "__main__":
|
|
|
|
| 7 |
|
| 8 |
from rbeval.plot.data import EvalGroup, get_samples
|
| 9 |
from rbeval.plot.score_cdf import (
|
| 10 |
+
CdfPlotConfig,
|
| 11 |
+
PlotData,
|
| 12 |
plot_with_data,
|
| 13 |
get_plot_data,
|
| 14 |
plot_cfgs,
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
@st.cache_data
|
| 29 |
+
def cached_score_cdf(
|
| 30 |
+
dir: Path, name_filter: Optional[str]
|
| 31 |
+
) -> tuple[List[PlotData], List[CdfPlotConfig]]:
|
| 32 |
samples = cached_samples(dir, name_filter)
|
| 33 |
cfgs = plot_cfgs()
|
| 34 |
data = [get_plot_data(cfg, samples) for cfg in cfgs]
|
|
|
|
| 47 |
return grouped_dict, base_name, comp_name
|
| 48 |
|
| 49 |
|
| 50 |
+
def filter_for_group(data: List[PlotData], group: str) -> List[PlotData]:
|
| 51 |
+
return [
|
| 52 |
+
PlotData(
|
| 53 |
+
renorm=[df for df in d.renorm if df["group"].iloc[0] == group],
|
| 54 |
+
norenorm=[df for df in d.norenorm if df["group"].iloc[0] == group],
|
| 55 |
+
)
|
| 56 |
+
for d in data
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_group_names(data: List[PlotData]) -> List[str]:
|
| 61 |
+
return sorted(set([df["group"].iloc[0] for d in data for df in d.renorm]))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
def main():
|
| 65 |
parser = argparse.ArgumentParser(description="rbeval dashboard")
|
| 66 |
parser.add_argument("--evals", type=str, default="./lmo-fake", required=False)
|
| 67 |
args, _rest = parser.parse_known_args()
|
| 68 |
eval_dir = Path(args.evals)
|
| 69 |
# Show all the models
|
| 70 |
+
|
| 71 |
st.set_page_config(layout="wide")
|
| 72 |
score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
|
| 73 |
+
group_names = sorted([g.name for g in cached_samples(eval_dir, None)])
|
| 74 |
+
renormed = st.toggle("Renormalize Probabilities", True)
|
| 75 |
+
|
| 76 |
+
st.subheader("Model Performance Curves")
|
| 77 |
+
for group in group_names:
|
| 78 |
+
group_data = filter_for_group(score_cdf_data, group)
|
| 79 |
+
with st.expander(group):
|
| 80 |
+
figs = [
|
| 81 |
+
fig
|
| 82 |
+
for data, cdf in zip(group_data, cfgs)
|
| 83 |
+
for fig in plot_with_data(cdf, data, renormed)
|
| 84 |
+
]
|
| 85 |
for fig in figs:
|
| 86 |
st.altair_chart(fig.chart) # type: ignore
|
| 87 |
|
|
|
|
| 92 |
for m in group.model_evals
|
| 93 |
]
|
| 94 |
)
|
| 95 |
+
with st.form("comp"):
|
| 96 |
+
st.subheader("Model Comparison Tool")
|
| 97 |
+
base_model = st.selectbox("Base model", model_names)
|
| 98 |
+
compare_model = st.selectbox("Compare model", model_names)
|
| 99 |
+
st.text(f"Comparing {base_model} with {compare_model}")
|
| 100 |
+
submitted = st.form_submit_button("Compare")
|
| 101 |
+
if base_model and compare_model and submitted:
|
| 102 |
+
print("Computing comparisons")
|
| 103 |
+
if base_model == compare_model:
|
| 104 |
+
st.text("Base and compare models are the same")
|
| 105 |
+
return
|
| 106 |
+
grouped, base_name, comp_name = cache_compare(
|
| 107 |
+
eval_dir, None, base_model, compare_model
|
| 108 |
+
)
|
| 109 |
+
grouped = {
|
| 110 |
+
k: [from_dict(model_comp.Scores, vi) for vi in v]
|
| 111 |
+
for k, v in grouped.items()
|
| 112 |
+
}
|
| 113 |
+
for fig in model_comp.get_figures(grouped, base_name, comp_name):
|
| 114 |
+
st.text(fig.name)
|
| 115 |
+
st.altair_chart(fig.chart) # type: ignore
|
| 116 |
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
src/rbeval/plot/data.py
CHANGED
|
@@ -151,3 +151,4 @@ class Figure:
|
|
| 151 |
| alt.ConcatChart
|
| 152 |
| alt.VConcatChart
|
| 153 |
)
|
|
|
|
|
|
| 151 |
| alt.ConcatChart
|
| 152 |
| alt.VConcatChart
|
| 153 |
)
|
| 154 |
+
group: Optional[str] = None
|
src/rbeval/plot/score_cdf.py
CHANGED
|
@@ -25,7 +25,8 @@ def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
|
|
| 25 |
return [
|
| 26 |
a
|
| 27 |
for cfg in plot_cfgs()
|
| 28 |
-
for
|
|
|
|
| 29 |
]
|
| 30 |
|
| 31 |
|
|
@@ -59,32 +60,36 @@ def get_plot_data(
|
|
| 59 |
def plot_with_data(
|
| 60 |
cfg: "CdfPlotConfig",
|
| 61 |
data: PlotData,
|
|
|
|
| 62 |
) -> List[Figure]:
|
| 63 |
figures: List[Figure] = []
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
),
|
| 80 |
-
)
|
| 81 |
-
.properties(title=cfg.title(group_name, renorm), width=800, height=400)
|
| 82 |
-
.resolve_legend(color="independent")
|
| 83 |
-
.resolve_axis(y="independent", x="independent")
|
| 84 |
-
.add_params(selection)
|
| 85 |
-
.interactive()
|
| 86 |
)
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
return figures
|
| 90 |
|
|
|
|
| 25 |
return [
|
| 26 |
a
|
| 27 |
for cfg in plot_cfgs()
|
| 28 |
+
for renorm in [True, False]
|
| 29 |
+
for a in plot_with_data(cfg, get_plot_data(cfg, samples), renorm)
|
| 30 |
]
|
| 31 |
|
| 32 |
|
|
|
|
| 60 |
def plot_with_data(
|
| 61 |
cfg: "CdfPlotConfig",
|
| 62 |
data: PlotData,
|
| 63 |
+
renorm: bool = True,
|
| 64 |
) -> List[Figure]:
|
| 65 |
figures: List[Figure] = []
|
| 66 |
+
group_dfs = data.renorm if renorm else data.norenorm
|
| 67 |
+
for df in group_dfs:
|
| 68 |
+
group_name: str = str(df["group"].iloc[0]) # type: ignore
|
| 69 |
+
label_selection = alt.selection_point(fields=["label"], bind="legend") # type: ignore
|
| 70 |
+
fs_selection = alt.selection_point(fields=["fewshot"], bind="legend") # type: ignore
|
| 71 |
+
chart = (
|
| 72 |
+
alt.Chart(df) # type: ignore
|
| 73 |
+
.mark_line()
|
| 74 |
+
.encode(
|
| 75 |
+
x=alt.X("x:Q", title=cfg.xlabel),
|
| 76 |
+
y=alt.Y("y:Q", title=cfg.ylabel),
|
| 77 |
+
color=alt.Color("label:N", legend=alt.Legend(symbolOpacity=1.0)),
|
| 78 |
+
opacity=alt.condition( # type: ignore
|
| 79 |
+
label_selection & fs_selection,
|
| 80 |
+
alt.Opacity("fewshot:O"),
|
| 81 |
+
alt.value(0.0), # type: ignore
|
| 82 |
+
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
)
|
| 84 |
+
.properties(title=cfg.title(group_name, renorm), width=800, height=400)
|
| 85 |
+
.resolve_legend(color="independent")
|
| 86 |
+
.resolve_axis(y="independent", x="independent")
|
| 87 |
+
.add_params(fs_selection, label_selection)
|
| 88 |
+
.interactive()
|
| 89 |
+
)
|
| 90 |
+
figures.append(
|
| 91 |
+
Figure(name=f"{group_name} {cfg.name}", chart=chart, group=group_name)
|
| 92 |
+
)
|
| 93 |
|
| 94 |
return figures
|
| 95 |
|