William Arnold commited on
Commit
48172f6
·
1 Parent(s): bafdc7f

Perf update

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [theme]
2
+ base="light"
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import sys
2
- sys.path.append('./src/')
3
- import rbeval.dash
4
 
 
 
5
 
6
- rbeval.dash.main()
 
 
1
  import sys
 
 
2
 
3
+ sys.path.append("./src/")
4
+ from rbeval.dash import main
5
 
6
+
7
+ main()
requirements.txt CHANGED
@@ -4,4 +4,4 @@ huggingface-hub>=0.24.2
4
  tqdm>=4.66.4
5
  numpy>=1.26.4
6
  dacite>=1.8.1
7
- seaborn>=0.13.1
 
4
  tqdm>=4.66.4
5
  numpy>=1.26.4
6
  dacite>=1.8.1
7
+ seaborn>=0.13.1
src/rbeval/dash.py CHANGED
@@ -7,6 +7,8 @@ from dacite import from_dict
7
 
8
  from rbeval.plot.data import EvalGroup, get_samples
9
  from rbeval.plot.score_cdf import (
 
 
10
  plot_with_data,
11
  get_plot_data,
12
  plot_cfgs,
@@ -24,7 +26,9 @@ def cached_samples(dir: Path, name_filter: Optional[str]) -> List[EvalGroup]:
24
 
25
 
26
  @st.cache_data
27
- def cached_score_cdf(dir: Path, name_filter: Optional[str]):
 
 
28
  samples = cached_samples(dir, name_filter)
29
  cfgs = plot_cfgs()
30
  data = [get_plot_data(cfg, samples) for cfg in cfgs]
@@ -43,17 +47,41 @@ def cache_compare(
43
  return grouped_dict, base_name, comp_name
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def main():
47
  parser = argparse.ArgumentParser(description="rbeval dashboard")
48
  parser.add_argument("--evals", type=str, default="./lmo-fake", required=False)
49
  args, _rest = parser.parse_known_args()
50
  eval_dir = Path(args.evals)
51
  # Show all the models
 
52
  st.set_page_config(layout="wide")
53
  score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
54
- for data, cfg in zip(score_cdf_data, cfgs):
55
- figs = plot_with_data(cfg, data)
56
- with st.expander(cfg.name):
 
 
 
 
 
 
 
 
 
57
  for fig in figs:
58
  st.altair_chart(fig.chart) # type: ignore
59
 
@@ -64,23 +92,27 @@ def main():
64
  for m in group.model_evals
65
  ]
66
  )
67
- base_model = st.selectbox("Base model", model_names)
68
- compare_model = st.selectbox("Compare model", model_names)
69
- st.text(f"Comparing {base_model} with {compare_model}")
70
- if base_model and compare_model:
71
- if base_model == compare_model:
72
- st.text("Base and compare models are the same")
73
- return
74
- grouped, base_name, comp_name = cache_compare(
75
- eval_dir, None, base_model, compare_model
76
- )
77
- grouped = {
78
- k: [from_dict(model_comp.Scores, vi) for vi in v]
79
- for k, v in grouped.items()
80
- }
81
- for fig in model_comp.get_figures(grouped, base_name, comp_name):
82
- st.text(fig.name)
83
- st.altair_chart(fig.chart) # type: ignore
 
 
 
 
84
 
85
 
86
  if __name__ == "__main__":
 
7
 
8
  from rbeval.plot.data import EvalGroup, get_samples
9
  from rbeval.plot.score_cdf import (
10
+ CdfPlotConfig,
11
+ PlotData,
12
  plot_with_data,
13
  get_plot_data,
14
  plot_cfgs,
 
26
 
27
 
28
  @st.cache_data
29
+ def cached_score_cdf(
30
+ dir: Path, name_filter: Optional[str]
31
+ ) -> tuple[List[PlotData], List[CdfPlotConfig]]:
32
  samples = cached_samples(dir, name_filter)
33
  cfgs = plot_cfgs()
34
  data = [get_plot_data(cfg, samples) for cfg in cfgs]
 
47
  return grouped_dict, base_name, comp_name
48
 
49
 
50
+ def filter_for_group(data: List[PlotData], group: str) -> List[PlotData]:
51
+ return [
52
+ PlotData(
53
+ renorm=[df for df in d.renorm if df["group"].iloc[0] == group],
54
+ norenorm=[df for df in d.norenorm if df["group"].iloc[0] == group],
55
+ )
56
+ for d in data
57
+ ]
58
+
59
+
60
+ def get_group_names(data: List[PlotData]) -> List[str]:
61
+ return sorted(set([df["group"].iloc[0] for d in data for df in d.renorm]))
62
+
63
+
64
  def main():
65
  parser = argparse.ArgumentParser(description="rbeval dashboard")
66
  parser.add_argument("--evals", type=str, default="./lmo-fake", required=False)
67
  args, _rest = parser.parse_known_args()
68
  eval_dir = Path(args.evals)
69
  # Show all the models
70
+
71
  st.set_page_config(layout="wide")
72
  score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
73
+ group_names = sorted([g.name for g in cached_samples(eval_dir, None)])
74
+ renormed = st.toggle("Renormalize Probabilities", True)
75
+
76
+ st.subheader("Model Performance Curves")
77
+ for group in group_names:
78
+ group_data = filter_for_group(score_cdf_data, group)
79
+ with st.expander(group):
80
+ figs = [
81
+ fig
82
+ for data, cdf in zip(group_data, cfgs)
83
+ for fig in plot_with_data(cdf, data, renormed)
84
+ ]
85
  for fig in figs:
86
  st.altair_chart(fig.chart) # type: ignore
87
 
 
92
  for m in group.model_evals
93
  ]
94
  )
95
+ with st.form("comp"):
96
+ st.subheader("Model Comparison Tool")
97
+ base_model = st.selectbox("Base model", model_names)
98
+ compare_model = st.selectbox("Compare model", model_names)
99
+ st.text(f"Comparing {base_model} with {compare_model}")
100
+ submitted = st.form_submit_button("Compare")
101
+ if base_model and compare_model and submitted:
102
+ print("Computing comparisons")
103
+ if base_model == compare_model:
104
+ st.text("Base and compare models are the same")
105
+ return
106
+ grouped, base_name, comp_name = cache_compare(
107
+ eval_dir, None, base_model, compare_model
108
+ )
109
+ grouped = {
110
+ k: [from_dict(model_comp.Scores, vi) for vi in v]
111
+ for k, v in grouped.items()
112
+ }
113
+ for fig in model_comp.get_figures(grouped, base_name, comp_name):
114
+ st.text(fig.name)
115
+ st.altair_chart(fig.chart) # type: ignore
116
 
117
 
118
  if __name__ == "__main__":
src/rbeval/plot/data.py CHANGED
@@ -151,3 +151,4 @@ class Figure:
151
  | alt.ConcatChart
152
  | alt.VConcatChart
153
  )
 
 
151
  | alt.ConcatChart
152
  | alt.VConcatChart
153
  )
154
+ group: Optional[str] = None
src/rbeval/plot/score_cdf.py CHANGED
@@ -25,7 +25,8 @@ def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
25
  return [
26
  a
27
  for cfg in plot_cfgs()
28
- for a in plot_with_data(cfg, get_plot_data(cfg, samples))
 
29
  ]
30
 
31
 
@@ -59,32 +60,36 @@ def get_plot_data(
59
  def plot_with_data(
60
  cfg: "CdfPlotConfig",
61
  data: PlotData,
 
62
  ) -> List[Figure]:
63
  figures: List[Figure] = []
64
- for renorm, group_dfs in zip([True, False], [data.renorm, data.norenorm]):
65
- for df in group_dfs:
66
- group_name: str = str(df["group"].iloc[0]) # type: ignore
67
- selection = alt.selection_point(fields=["label"], bind="legend") # type: ignore
68
- chart = (
69
- alt.Chart(df) # type: ignore
70
- .mark_line()
71
- .encode(
72
- x=alt.X("x:Q", title=cfg.xlabel),
73
- y=alt.Y("y:Q", title=cfg.ylabel),
74
- color=alt.Color("label:N", legend=alt.Legend(symbolOpacity=1.0)),
75
- opacity=alt.condition( # type: ignore
76
- selection,
77
- alt.Opacity("fewshot:O"),
78
- alt.value(0.1), # type: ignore
79
- ),
80
- )
81
- .properties(title=cfg.title(group_name, renorm), width=800, height=400)
82
- .resolve_legend(color="independent")
83
- .resolve_axis(y="independent", x="independent")
84
- .add_params(selection)
85
- .interactive()
86
  )
87
- figures.append(Figure(name=f"{group_name} {cfg.name}", chart=chart))
 
 
 
 
 
 
 
 
88
 
89
  return figures
90
 
 
25
  return [
26
  a
27
  for cfg in plot_cfgs()
28
+ for renorm in [True, False]
29
+ for a in plot_with_data(cfg, get_plot_data(cfg, samples), renorm)
30
  ]
31
 
32
 
 
60
  def plot_with_data(
61
  cfg: "CdfPlotConfig",
62
  data: PlotData,
63
+ renorm: bool = True,
64
  ) -> List[Figure]:
65
  figures: List[Figure] = []
66
+ group_dfs = data.renorm if renorm else data.norenorm
67
+ for df in group_dfs:
68
+ group_name: str = str(df["group"].iloc[0]) # type: ignore
69
+ label_selection = alt.selection_point(fields=["label"], bind="legend") # type: ignore
70
+ fs_selection = alt.selection_point(fields=["fewshot"], bind="legend") # type: ignore
71
+ chart = (
72
+ alt.Chart(df) # type: ignore
73
+ .mark_line()
74
+ .encode(
75
+ x=alt.X("x:Q", title=cfg.xlabel),
76
+ y=alt.Y("y:Q", title=cfg.ylabel),
77
+ color=alt.Color("label:N", legend=alt.Legend(symbolOpacity=1.0)),
78
+ opacity=alt.condition( # type: ignore
79
+ label_selection & fs_selection,
80
+ alt.Opacity("fewshot:O"),
81
+ alt.value(0.0), # type: ignore
82
+ ),
 
 
 
 
 
83
  )
84
+ .properties(title=cfg.title(group_name, renorm), width=800, height=400)
85
+ .resolve_legend(color="independent")
86
+ .resolve_axis(y="independent", x="independent")
87
+ .add_params(fs_selection, label_selection)
88
+ .interactive()
89
+ )
90
+ figures.append(
91
+ Figure(name=f"{group_name} {cfg.name}", chart=chart, group=group_name)
92
+ )
93
 
94
  return figures
95