William Arnold committed on
Commit
ca11d0f
·
1 Parent(s): e3228e0

Ready for spaces

Browse files
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: rbeval
3
+ emoji: 💩
4
+ colorFrom: yellow
5
+ colorTo: orange
6
+ sdk: streamlit
7
+ sdk_version: 1.25.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
app.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import rbeval.dash
2
+
3
+ rbeval.dash.main()
pyproject.toml CHANGED
@@ -7,6 +7,12 @@ name = "rbeval"
7
  requires-python = ">=3.8"
8
  dynamic = ["version"]
9
  dependencies = [
 
 
 
 
 
 
10
  "seaborn>=0.13.2"
11
  ]
12
 
 
7
  requires-python = ">=3.8"
8
  dynamic = ["version"]
9
  dependencies = [
10
+ "pandas>=2.2.2",
11
+ "matplotlib>=3.9.1",
12
+ "huggingface-hub>=0.24.2",
13
+ "tqdm>=4.66.4",
14
+ "numpy>=1.26.4",
15
+ "dacite>=1.8.1",
16
  "seaborn>=0.13.2"
17
  ]
18
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ pandas>=2.2.2
2
+ matplotlib>=3.9.1
3
+ huggingface-hub>=0.24.2
4
+ tqdm>=4.66.4
5
+ numpy>=1.26.4
6
+ dacite>=1.8.1
7
+ seaborn>=0.13.2
8
+ .
src/rbeval/dash.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Optional
3
+ import streamlit as st
4
+ import argparse
5
+
6
+ from rbeval.plot.data import EvalGroup, get_samples
7
+ from rbeval.plot.score_cdf_altair import (
8
+ plot_with_data,
9
+ get_plot_data,
10
+ plot_cfgs,
11
+ )
12
+ from rbeval.plot import model_comp
13
+ from huggingface_hub import snapshot_download
14
+
15
+
16
@st.cache_resource
def cached_samples(dir: Path, name_filter: Optional[str]) -> List[EvalGroup]:
    """Load eval sample groups, fetching the HF dataset snapshot if *dir* is absent.

    Cached as a Streamlit resource so the parsed groups are shared across
    reruns and sessions instead of being reloaded each time.
    """
    source = dir if dir.exists() else Path(snapshot_download("mli-will/rbeval"))
    return get_samples(source, name_filter)
22
+
23
+
24
@st.cache_data
def cached_score_cdf(dir, name_filter):
    """Build the CDF plot data for every plot config over the samples in *dir*.

    Returns a ``(data, cfgs)`` pair where ``data[i]`` is the plot data for
    ``cfgs[i]``.
    """
    cfgs = plot_cfgs()
    samples = cached_samples(dir, name_filter)
    plot_data = [get_plot_data(cfg, samples) for cfg in cfgs]
    return plot_data, cfgs
30
+
31
+
32
@st.cache_data
def cache_compare(dir, name_filter, base_name, compare_name):
    """Compute base-vs-compare model scores in a cache-friendly dict form.

    The model names are anchored with a trailing ``$`` so each selectbox
    choice matches exactly one model. Scores are converted to plain dicts
    so st.cache_data can serialize the result.
    """
    samples = cached_samples(dir, name_filter)
    grouped, resolved_base, resolved_comp = model_comp.get_scores(
        samples, f"{base_name}$", f"{compare_name}$"
    )
    serializable = {
        title: [score.to_dict() for score in score_list]
        for title, score_list in grouped.items()
    }
    return serializable, resolved_base, resolved_comp
40
+
41
+
42
def main():
    """Entry point for the rbeval Streamlit dashboard.

    Renders the score-CDF perf curves for every model found under the eval
    directory, then an optional pairwise model comparison chosen via two
    selectboxes.
    """
    parser = argparse.ArgumentParser(description="rbeval dashboard")
    parser.add_argument("eval_dir", type=str)
    # parse_known_args: silently ignore any extra args Streamlit passes through.
    args, _ = parser.parse_known_args()
    eval_dir = Path(args.eval_dir)

    # Show all the models
    score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
    for data, cfg in zip(score_cdf_data, cfgs):
        figs = plot_with_data(cfg, data)
        with st.expander(cfg.name):
            for fig in figs:
                st.altair_chart(fig.chart)

    # Set comprehension instead of set([list-comp]) (flake8-comprehensions C403).
    model_names = {
        m.model_name
        for group in cached_samples(eval_dir, None)
        for m in group.model_evals
    }
    base_model = st.selectbox("Base model", model_names)
    compare_model = st.selectbox("Compare model", model_names)
    st.write(f"Comparing {base_model} with {compare_model}")
    if base_model and compare_model:
        if base_model == compare_model:
            st.write("Base and compare models are the same")
            return
        grouped, base_name, comp_name = cache_compare(
            eval_dir, None, base_model, compare_model
        )
        # Rehydrate the cached plain dicts back into Scores objects.
        grouped = {
            k: [model_comp.Scores.from_dict(vi) for vi in v]
            for k, v in grouped.items()
        }
        for fig in model_comp.get_figures(grouped, base_name, comp_name):
            st.write(fig.name)
            st.altair_chart(fig.chart)


if __name__ == "__main__":
    main()
src/rbeval/dash/__main__.py DELETED
File without changes
src/rbeval/plot/data.py CHANGED
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional
7
  from collections import defaultdict
8
  import altair as alt
9
 
 
10
  import numpy as np
11
  from tqdm import tqdm
12
 
@@ -26,18 +27,21 @@ def get_samples(inp: Path, name_filter: Optional[str]) -> List["EvalGroup"]:
26
  print(f"Skipping spec {spec_file.stem}")
27
  continue
28
 
29
- group = groups.setdefault(spec.group, EvalGroup(name=spec.group))
30
- model_eval = ModelEval(eval_spec=spec)
31
- group.model_evals.append(model_eval)
32
- for samples_file in (spec_file.parent / spec_file.stem).glob(
33
- "**/samples_*.json*"
34
- ):
35
- cache_file = samples_file.with_suffix(".npy")
36
- if samples_file.with_suffix(".npy").exists():
37
- model_eval.evals.append(
38
- Eval(**np.load(str(cache_file), allow_pickle=True).item())
39
- )
40
- else:
 
 
 
41
  with open(samples_file, "r") as f:
42
  if samples_file.suffix == ".jsonl":
43
  docs = [json.loads(s) for s in f.readlines()]
@@ -57,8 +61,8 @@ def get_samples(inp: Path, name_filter: Optional[str]) -> List["EvalGroup"]:
57
  cor_logprobs=np.array(cor_logprobs),
58
  inc_logprobs=np.array(inc_logprobs),
59
  )
60
- np.save(str(cache_file), asdict(eval)) # type: ignore
61
  model_eval.evals.append(eval)
 
62
 
63
  return list(groups.values())
64
 
 
7
  from collections import defaultdict
8
  import altair as alt
9
 
10
+ from dacite import from_dict
11
  import numpy as np
12
  from tqdm import tqdm
13
 
 
27
  print(f"Skipping spec {spec_file.stem}")
28
  continue
29
 
30
+ group_cache_file = Path(
31
+ spec_file.with_stem(spec_file.stem + "_group_cache")
32
+ ).with_suffix(".npy")
33
+ if group_cache_file.exists():
34
+ res_dict = np.load(str(group_cache_file), allow_pickle=True).item()
35
+ group = from_dict(data_class=EvalGroup, data=res_dict)
36
+ groups[group.name] = group
37
+ continue
38
+ else:
39
+ group = groups.setdefault(spec.group, EvalGroup(name=spec.group))
40
+ model_eval = ModelEval(eval_spec=spec)
41
+ group.model_evals.append(model_eval)
42
+ for samples_file in (spec_file.parent / spec_file.stem).glob(
43
+ "**/samples_*.json*"
44
+ ):
45
  with open(samples_file, "r") as f:
46
  if samples_file.suffix == ".jsonl":
47
  docs = [json.loads(s) for s in f.readlines()]
 
61
  cor_logprobs=np.array(cor_logprobs),
62
  inc_logprobs=np.array(inc_logprobs),
63
  )
 
64
  model_eval.evals.append(eval)
65
+ np.save(str(group_cache_file), asdict(group)) # type: ignore
66
 
67
  return list(groups.values())
68
 
src/rbeval/plot/model_comp.py CHANGED
@@ -2,7 +2,7 @@ import argparse
2
  import altair as alt
3
  import pandas as pd
4
  from collections import defaultdict
5
- from dataclasses import dataclass, field
6
  import itertools
7
  from typing import Dict, List, Optional
8
  import warnings
@@ -20,21 +20,18 @@ class Scores:
20
  cor_minus_inc_samples: List[np.ndarray] = field(default_factory=list)
21
  cor_samples: List[np.ndarray] = field(default_factory=list)
22
 
 
 
23
 
24
- def model_comparer(samples: List[EvalGroup], rem_args: List[str]) -> List[Figure]:
25
- parser = argparse.ArgumentParser()
26
- parser.add_argument("--base", type=str)
27
- parser.add_argument("--compare", type=str)
28
- args = parser.parse_args(rem_args)
29
- base_name_filt: Optional[str] = args.base
30
- comp_name_filt: Optional[str] = args.compare
31
 
32
- if base_name_filt is None or comp_name_filt is None:
33
- warnings.warn(
34
- "Skipping model comparison plot, need to specify base and compare"
35
- )
36
- return []
37
 
 
 
 
38
  bases: List[ModelEval] = list(
39
  itertools.chain.from_iterable(
40
  g.collect_with_name(base_name_filt) for g in samples
@@ -107,6 +104,10 @@ def model_comparer(samples: List[EvalGroup], rem_args: List[str]) -> List[Figure
107
  for title, scores in scores_by_mask.items():
108
  grouped[title].append(scores)
109
 
 
 
 
 
110
  cmp_name = f"{base_name} to {comp_name}"
111
  return [
112
  Figure(name=f"{cmp_name} prob diff perf curves", chart=plot_diff_cdf(grouped)),
@@ -115,6 +116,24 @@ def model_comparer(samples: List[EvalGroup], rem_args: List[str]) -> List[Figure
115
  ]
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  def plot_diff_cdf(grouped: Dict[str, List[Scores]]) -> alt.HConcatChart:
119
  charts = []
120
  for title, score_list in grouped.items():
 
2
  import altair as alt
3
  import pandas as pd
4
  from collections import defaultdict
5
+ from dataclasses import asdict, dataclass, field
6
  import itertools
7
  from typing import Dict, List, Optional
8
  import warnings
 
20
  cor_minus_inc_samples: List[np.ndarray] = field(default_factory=list)
21
  cor_samples: List[np.ndarray] = field(default_factory=list)
22
 
23
+ def to_dict(self):
24
+ return asdict(self)
25
 
26
+ @classmethod
27
+ def from_dict(cls, d: dict):
28
+ d["spec"] = EvalSpec(**d["spec"])
29
+ return cls(**d)
 
 
 
30
 
 
 
 
 
 
31
 
32
+ def get_scores(
33
+ samples: List[EvalGroup], base_name_filt: str, comp_name_filt: str
34
+ ) -> tuple[Dict[str, List[Scores]], str, str]:
35
  bases: List[ModelEval] = list(
36
  itertools.chain.from_iterable(
37
  g.collect_with_name(base_name_filt) for g in samples
 
104
  for title, scores in scores_by_mask.items():
105
  grouped[title].append(scores)
106
 
107
+ return grouped, base_name, comp_name
108
+
109
+
110
+ def get_figures(grouped: Dict[str, List[Scores]], base_name, comp_name) -> List[Figure]:
111
  cmp_name = f"{base_name} to {comp_name}"
112
  return [
113
  Figure(name=f"{cmp_name} prob diff perf curves", chart=plot_diff_cdf(grouped)),
 
116
  ]
117
 
118
 
119
def model_comparer(samples: List[EvalGroup], rem_args: List[str]) -> List[Figure]:
    """CLI-facing comparison plot builder.

    Parses ``--base``/``--compare`` out of *rem_args*; when either is missing
    the comparison is skipped with a warning and no figures are produced.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base", type=str)
    parser.add_argument("--compare", type=str)
    opts = parser.parse_args(rem_args)

    if opts.base is None or opts.compare is None:
        warnings.warn(
            "Skipping model comparison plot, need to specify base and compare"
        )
        return []

    grouped, base_name, comp_name = get_scores(samples, opts.base, opts.compare)
    return get_figures(grouped, base_name, comp_name)
135
+
136
+
137
  def plot_diff_cdf(grouped: Dict[str, List[Scores]]) -> alt.HConcatChart:
138
  charts = []
139
  for title, score_list in grouped.items():
src/rbeval/plot/score_cdf_altair.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import List
2
 
3
  from rbeval.plot.data import Eval, EvalGroup, Figure
@@ -9,25 +10,31 @@ import pandas as pd
9
  from rbeval.plot.utils import CdfData, renormed
10
 
11
 
 
 
 
 
 
 
 
 
 
 
12
  def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
13
  return [
14
- Figure(
15
- name="Correct Prob Perf Curve",
16
- chart=plot_with_config(CorrectProbCdfPlot(), samples),
17
- ),
18
- Figure(
19
- name="Corr-Incorr Gap Perf Curve",
20
- chart=plot_with_config(CorrIncorrDiffConfig(), samples),
21
- ),
22
  ]
23
 
24
 
25
- def plot_with_config(
26
  cfg: "CdfPlotConfig",
27
  samples: List[EvalGroup],
28
- ) -> alt.ConcatChart:
29
- group_dfs = []
30
  for renorm in [True, False]:
 
31
  for group in samples:
32
  dfs = []
33
  for m in group.model_evals:
@@ -38,43 +45,52 @@ def plot_with_config(
38
  "x": cdf.scores,
39
  "y": cdf.cdf_p,
40
  "label": m.model_name,
 
41
  "renorm": renorm,
42
  "fewshot": spec.fewshot,
43
  }
44
  )
45
  dfs.append(df)
46
- group_dfs.append(pd.concat(dfs))
47
-
48
- selection = alt.selection_point(fields=["label"], bind="legend")
49
- charts = []
50
- for group, df in zip(samples, group_dfs):
51
- chart = (
52
- alt.Chart(df)
53
- .mark_line()
54
- .encode(
55
- x=alt.X("x:Q", title=cfg.xlabel),
56
- y=alt.Y("y:Q", title=cfg.ylabel),
57
- color=alt.Color("label:N", legend=alt.Legend(symbolOpacity=1.0)),
58
- opacity=alt.condition(
59
- selection, alt.Opacity("fewshot:O"), alt.value(0.1)
60
- ),
61
- )
62
- .properties(title=cfg.title(group.name, renorm))
63
- .resolve_legend(color="independent")
64
- )
65
 
66
- charts.append(chart)
67
 
68
- final_chart = (
69
- alt.concat(*charts, columns=len(samples)).add_params(selection).interactive()
70
- )
71
- return final_chart
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
 
74
  class CdfPlotConfig(ABC):
75
  plot_type: str
76
  xlabel: str
77
  ylabel: str
 
78
 
79
  @abstractmethod
80
  def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "CdfData":
@@ -92,6 +108,8 @@ class CdfPlotConfig(ABC):
92
 
93
 
94
  class CorrectProbCdfPlot(CdfPlotConfig):
 
 
95
  def __init__(self):
96
  self.plot_type = "corr perf plot"
97
  self.xlabel = "Correct answer probability"
@@ -112,6 +130,8 @@ class CorrectProbCdfPlot(CdfPlotConfig):
112
 
113
 
114
  class CorrIncorrDiffConfig(CdfPlotConfig):
 
 
115
  def __init__(self):
116
  self.plot_type = "corr-max(incor) perf plot"
117
  self.xlabel = "corr prob - max(incor prob)"
 
1
+ from dataclasses import dataclass, field
2
  from typing import List
3
 
4
  from rbeval.plot.data import Eval, EvalGroup, Figure
 
10
  from rbeval.plot.utils import CdfData, renormed
11
 
12
 
13
@dataclass
class PlotData:
    """Per-config CDF dataframes, one concatenated frame appended per eval group.

    ``renorm`` holds the frames built with renorm=True, ``norenorm`` the ones
    built with renorm=False.
    """

    renorm: List[pd.DataFrame] = field(default_factory=list)
    norenorm: List[pd.DataFrame] = field(default_factory=list)
+
18
+
19
def plot_cfgs():
    """Return fresh instances of every CDF plot configuration."""
    configs = [CorrectProbCdfPlot(), CorrIncorrDiffConfig()]
    return configs
21
+
22
+
23
def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
    """Build every score-CDF figure across all plot configs.

    *args* is accepted but unused here.
    """
    figures: List[Figure] = []
    for cfg in plot_cfgs():
        figures.extend(plot_with_data(cfg, get_plot_data(cfg, samples)))
    return figures
29
 
30
 
31
+ def get_plot_data(
32
  cfg: "CdfPlotConfig",
33
  samples: List[EvalGroup],
34
+ ) -> PlotData:
35
+ data = PlotData()
36
  for renorm in [True, False]:
37
+ gfs = data.renorm if renorm else data.norenorm
38
  for group in samples:
39
  dfs = []
40
  for m in group.model_evals:
 
45
  "x": cdf.scores,
46
  "y": cdf.cdf_p,
47
  "label": m.model_name,
48
+ "group": group.name,
49
  "renorm": renorm,
50
  "fewshot": spec.fewshot,
51
  }
52
  )
53
  dfs.append(df)
54
+ gfs.append(pd.concat(dfs))
55
+ return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
57
 
58
def plot_with_data(
    cfg: "CdfPlotConfig",
    data: PlotData,
) -> List[Figure]:
    """Turn one plot config's PlotData into interactive Altair figures.

    Emits one figure per (renorm setting, group dataframe) pair; a
    legend-bound selection dims the curves of unselected models.
    """
    figures: List[Figure] = []
    for renorm, group_dfs in ((True, data.renorm), (False, data.norenorm)):
        for df in group_dfs:
            # Every row of a group frame carries the same group name.
            group_name = df["group"].iloc[0]
            selection = alt.selection_point(fields=["label"], bind="legend")
            encoded = alt.Chart(df).mark_line().encode(
                x=alt.X("x:Q", title=cfg.xlabel),
                y=alt.Y("y:Q", title=cfg.ylabel),
                color=alt.Color("label:N", legend=alt.Legend(symbolOpacity=1.0)),
                opacity=alt.condition(
                    selection, alt.Opacity("fewshot:O"), alt.value(0.1)
                ),
            )
            chart = (
                encoded.properties(
                    title=cfg.title(group_name, renorm), width=800, height=400
                )
                .resolve_legend(color="independent")
                .resolve_axis(y="independent", x="independent")
                .add_params(selection)
                .interactive()
            )
            figures.append(Figure(name=f"{group_name} {cfg.name}", chart=chart))

    return figures
87
 
88
 
89
  class CdfPlotConfig(ABC):
90
  plot_type: str
91
  xlabel: str
92
  ylabel: str
93
+ name: str = ""
94
 
95
  @abstractmethod
96
  def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "CdfData":
 
108
 
109
 
110
  class CorrectProbCdfPlot(CdfPlotConfig):
111
+ name = "Correct Prob Perf Curve"
112
+
113
  def __init__(self):
114
  self.plot_type = "corr perf plot"
115
  self.xlabel = "Correct answer probability"
 
130
 
131
 
132
  class CorrIncorrDiffConfig(CdfPlotConfig):
133
+ name = "Corr-Incorr Gap Perf Curve"
134
+
135
  def __init__(self):
136
  self.plot_type = "corr-max(incor) perf plot"
137
  self.xlabel = "corr prob - max(incor prob)"
src/rbeval/plot/utils.py CHANGED
@@ -59,7 +59,7 @@ class CdfData:
59
 
60
  @classmethod
61
  def from_weights(
62
- cls, weights: np.ndarray, scores: np.ndarray, max_p=1000
63
  ) -> "CdfData":
64
  sort_perm = scores.argsort()
65
  base_weights = weights[sort_perm]
 
59
 
60
  @classmethod
61
  def from_weights(
62
+ cls, weights: np.ndarray, scores: np.ndarray, max_p=600
63
  ) -> "CdfData":
64
  sort_perm = scores.argsort()
65
  base_weights = weights[sort_perm]