William Arnold
committed on
Commit
·
ca11d0f
1
Parent(s):
e3228e0
Ready for spaces
Browse files
- README.md +10 -0
- app.py +3 -0
- pyproject.toml +6 -0
- requirements.txt +8 -0
- src/rbeval/dash.py +81 -0
- src/rbeval/dash/__main__.py +0 -0
- src/rbeval/plot/data.py +17 -13
- src/rbeval/plot/model_comp.py +32 -13
- src/rbeval/plot/score_cdf_altair.py +55 -35
- src/rbeval/plot/utils.py +1 -1
README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: rbeval
|
| 3 |
+
emoji: 💩
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: orange
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.25.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
app.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import rbeval.dash
|
| 2 |
+
|
| 3 |
+
rbeval.dash.main()
|
pyproject.toml
CHANGED
|
@@ -7,6 +7,12 @@ name = "rbeval"
|
|
| 7 |
requires-python = ">=3.8"
|
| 8 |
dynamic = ["version"]
|
| 9 |
dependencies = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
"seaborn>=0.13.2"
|
| 11 |
]
|
| 12 |
|
|
|
|
| 7 |
requires-python = ">=3.8"
|
| 8 |
dynamic = ["version"]
|
| 9 |
dependencies = [
|
| 10 |
+
"pandas>=2.2.2",
|
| 11 |
+
"matplotlib>=3.9.1",
|
| 12 |
+
"huggingface-hub>=0.24.2",
|
| 13 |
+
"tqdm>=4.66.4",
|
| 14 |
+
"numpy>=1.26.4",
|
| 15 |
+
"dacite>=1.8.1",
|
| 16 |
"seaborn>=0.13.2"
|
| 17 |
]
|
| 18 |
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas>=2.2.2
|
| 2 |
+
matplotlib>=3.9.1
|
| 3 |
+
huggingface-hub>=0.24.2
|
| 4 |
+
tqdm>=4.66.4
|
| 5 |
+
numpy>=1.26.4
|
| 6 |
+
dacite>=1.8.1
|
| 7 |
+
seaborn>=0.13.1
|
| 8 |
+
.
|
src/rbeval/dash.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
from rbeval.plot.data import EvalGroup, get_samples
|
| 7 |
+
from rbeval.plot.score_cdf_altair import (
|
| 8 |
+
plot_with_data,
|
| 9 |
+
get_plot_data,
|
| 10 |
+
plot_cfgs,
|
| 11 |
+
)
|
| 12 |
+
from rbeval.plot import model_comp
|
| 13 |
+
from huggingface_hub import snapshot_download
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@st.cache_resource
|
| 17 |
+
def cached_samples(dir: Path, name_filter: Optional[str]) -> List[EvalGroup]:
|
| 18 |
+
if not dir.exists():
|
| 19 |
+
dir = Path(snapshot_download("mli-will/rbeval"))
|
| 20 |
+
samples = get_samples(dir, name_filter)
|
| 21 |
+
return samples
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@st.cache_data
|
| 25 |
+
def cached_score_cdf(dir, name_filter):
|
| 26 |
+
samples = cached_samples(dir, name_filter)
|
| 27 |
+
cfgs = plot_cfgs()
|
| 28 |
+
data = [get_plot_data(cfg, samples) for cfg in cfgs]
|
| 29 |
+
return data, cfgs
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@st.cache_data
|
| 33 |
+
def cache_compare(dir, name_filter, base_name, compare_name):
|
| 34 |
+
samples = cached_samples(dir, name_filter)
|
| 35 |
+
grouped, base_name, comp_name = model_comp.get_scores(
|
| 36 |
+
samples, base_name + "$", compare_name + "$"
|
| 37 |
+
)
|
| 38 |
+
grouped_dict = {k: [vi.to_dict() for vi in v] for k, v in grouped.items()}
|
| 39 |
+
return grouped_dict, base_name, comp_name
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def main():
|
| 43 |
+
parser = argparse.ArgumentParser(description="rbeval dashboard")
|
| 44 |
+
parser.add_argument("eval_dir", type=str)
|
| 45 |
+
args, rest = parser.parse_known_args()
|
| 46 |
+
eval_dir = Path(args.eval_dir)
|
| 47 |
+
# Show all the models
|
| 48 |
+
score_cdf_data, cfgs = cached_score_cdf(eval_dir, None)
|
| 49 |
+
for data, cfg in zip(score_cdf_data, cfgs):
|
| 50 |
+
figs = plot_with_data(cfg, data)
|
| 51 |
+
with st.expander(cfg.name):
|
| 52 |
+
for fig in figs:
|
| 53 |
+
st.altair_chart(fig.chart)
|
| 54 |
+
|
| 55 |
+
model_names = set(
|
| 56 |
+
[
|
| 57 |
+
m.model_name
|
| 58 |
+
for group in cached_samples(eval_dir, None)
|
| 59 |
+
for m in group.model_evals
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
base_model = st.selectbox("Base model", model_names)
|
| 63 |
+
compare_model = st.selectbox("Compare model", model_names)
|
| 64 |
+
st.write(f"Comparing {base_model} with {compare_model}")
|
| 65 |
+
if base_model and compare_model:
|
| 66 |
+
if base_model == compare_model:
|
| 67 |
+
st.write("Base and compare models are the same")
|
| 68 |
+
return
|
| 69 |
+
grouped, base_name, comp_name = cache_compare(
|
| 70 |
+
eval_dir, None, base_model, compare_model
|
| 71 |
+
)
|
| 72 |
+
grouped = {
|
| 73 |
+
k: [model_comp.Scores.from_dict(vi) for vi in v] for k, v in grouped.items()
|
| 74 |
+
}
|
| 75 |
+
for fig in model_comp.get_figures(grouped, base_name, comp_name):
|
| 76 |
+
st.write(fig.name)
|
| 77 |
+
st.altair_chart(fig.chart)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
main()
|
src/rbeval/dash/__main__.py
DELETED
|
File without changes
|
src/rbeval/plot/data.py
CHANGED
|
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional
|
|
| 7 |
from collections import defaultdict
|
| 8 |
import altair as alt
|
| 9 |
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
from tqdm import tqdm
|
| 12 |
|
|
@@ -26,18 +27,21 @@ def get_samples(inp: Path, name_filter: Optional[str]) -> List["EvalGroup"]:
|
|
| 26 |
print(f"Skipping spec {spec_file.stem}")
|
| 27 |
continue
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
| 41 |
with open(samples_file, "r") as f:
|
| 42 |
if samples_file.suffix == ".jsonl":
|
| 43 |
docs = [json.loads(s) for s in f.readlines()]
|
|
@@ -57,8 +61,8 @@ def get_samples(inp: Path, name_filter: Optional[str]) -> List["EvalGroup"]:
|
|
| 57 |
cor_logprobs=np.array(cor_logprobs),
|
| 58 |
inc_logprobs=np.array(inc_logprobs),
|
| 59 |
)
|
| 60 |
-
np.save(str(cache_file), asdict(eval)) # type: ignore
|
| 61 |
model_eval.evals.append(eval)
|
|
|
|
| 62 |
|
| 63 |
return list(groups.values())
|
| 64 |
|
|
|
|
| 7 |
from collections import defaultdict
|
| 8 |
import altair as alt
|
| 9 |
|
| 10 |
+
from dacite import from_dict
|
| 11 |
import numpy as np
|
| 12 |
from tqdm import tqdm
|
| 13 |
|
|
|
|
| 27 |
print(f"Skipping spec {spec_file.stem}")
|
| 28 |
continue
|
| 29 |
|
| 30 |
+
group_cache_file = Path(
|
| 31 |
+
spec_file.with_stem(spec_file.stem + "_group_cache")
|
| 32 |
+
).with_suffix(".npy")
|
| 33 |
+
if group_cache_file.exists():
|
| 34 |
+
res_dict = np.load(str(group_cache_file), allow_pickle=True).item()
|
| 35 |
+
group = from_dict(data_class=EvalGroup, data=res_dict)
|
| 36 |
+
groups[group.name] = group
|
| 37 |
+
continue
|
| 38 |
+
else:
|
| 39 |
+
group = groups.setdefault(spec.group, EvalGroup(name=spec.group))
|
| 40 |
+
model_eval = ModelEval(eval_spec=spec)
|
| 41 |
+
group.model_evals.append(model_eval)
|
| 42 |
+
for samples_file in (spec_file.parent / spec_file.stem).glob(
|
| 43 |
+
"**/samples_*.json*"
|
| 44 |
+
):
|
| 45 |
with open(samples_file, "r") as f:
|
| 46 |
if samples_file.suffix == ".jsonl":
|
| 47 |
docs = [json.loads(s) for s in f.readlines()]
|
|
|
|
| 61 |
cor_logprobs=np.array(cor_logprobs),
|
| 62 |
inc_logprobs=np.array(inc_logprobs),
|
| 63 |
)
|
|
|
|
| 64 |
model_eval.evals.append(eval)
|
| 65 |
+
np.save(str(group_cache_file), asdict(group)) # type: ignore
|
| 66 |
|
| 67 |
return list(groups.values())
|
| 68 |
|
src/rbeval/plot/model_comp.py
CHANGED
|
@@ -2,7 +2,7 @@ import argparse
|
|
| 2 |
import altair as alt
|
| 3 |
import pandas as pd
|
| 4 |
from collections import defaultdict
|
| 5 |
-
from dataclasses import dataclass, field
|
| 6 |
import itertools
|
| 7 |
from typing import Dict, List, Optional
|
| 8 |
import warnings
|
|
@@ -20,21 +20,18 @@ class Scores:
|
|
| 20 |
cor_minus_inc_samples: List[np.ndarray] = field(default_factory=list)
|
| 21 |
cor_samples: List[np.ndarray] = field(default_factory=list)
|
| 22 |
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
args = parser.parse_args(rem_args)
|
| 29 |
-
base_name_filt: Optional[str] = args.base
|
| 30 |
-
comp_name_filt: Optional[str] = args.compare
|
| 31 |
|
| 32 |
-
if base_name_filt is None or comp_name_filt is None:
|
| 33 |
-
warnings.warn(
|
| 34 |
-
"Skipping model comparison plot, need to specify base and compare"
|
| 35 |
-
)
|
| 36 |
-
return []
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
bases: List[ModelEval] = list(
|
| 39 |
itertools.chain.from_iterable(
|
| 40 |
g.collect_with_name(base_name_filt) for g in samples
|
|
@@ -107,6 +104,10 @@ def model_comparer(samples: List[EvalGroup], rem_args: List[str]) -> List[Figure
|
|
| 107 |
for title, scores in scores_by_mask.items():
|
| 108 |
grouped[title].append(scores)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
cmp_name = f"{base_name} to {comp_name}"
|
| 111 |
return [
|
| 112 |
Figure(name=f"{cmp_name} prob diff perf curves", chart=plot_diff_cdf(grouped)),
|
|
@@ -115,6 +116,24 @@ def model_comparer(samples: List[EvalGroup], rem_args: List[str]) -> List[Figure
|
|
| 115 |
]
|
| 116 |
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
def plot_diff_cdf(grouped: Dict[str, List[Scores]]) -> alt.HConcatChart:
|
| 119 |
charts = []
|
| 120 |
for title, score_list in grouped.items():
|
|
|
|
| 2 |
import altair as alt
|
| 3 |
import pandas as pd
|
| 4 |
from collections import defaultdict
|
| 5 |
+
from dataclasses import asdict, dataclass, field
|
| 6 |
import itertools
|
| 7 |
from typing import Dict, List, Optional
|
| 8 |
import warnings
|
|
|
|
| 20 |
cor_minus_inc_samples: List[np.ndarray] = field(default_factory=list)
|
| 21 |
cor_samples: List[np.ndarray] = field(default_factory=list)
|
| 22 |
|
| 23 |
+
def to_dict(self):
|
| 24 |
+
return asdict(self)
|
| 25 |
|
| 26 |
+
@classmethod
|
| 27 |
+
def from_dict(cls, d: dict):
|
| 28 |
+
d["spec"] = EvalSpec(**d["spec"])
|
| 29 |
+
return cls(**d)
|
|
|
|
|
|
|
|
|
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
def get_scores(
|
| 33 |
+
samples: List[EvalGroup], base_name_filt: str, comp_name_filt: str
|
| 34 |
+
) -> tuple[Dict[str, List[Scores]], str, str]:
|
| 35 |
bases: List[ModelEval] = list(
|
| 36 |
itertools.chain.from_iterable(
|
| 37 |
g.collect_with_name(base_name_filt) for g in samples
|
|
|
|
| 104 |
for title, scores in scores_by_mask.items():
|
| 105 |
grouped[title].append(scores)
|
| 106 |
|
| 107 |
+
return grouped, base_name, comp_name
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_figures(grouped: Dict[str, List[Scores]], base_name, comp_name) -> List[Figure]:
|
| 111 |
cmp_name = f"{base_name} to {comp_name}"
|
| 112 |
return [
|
| 113 |
Figure(name=f"{cmp_name} prob diff perf curves", chart=plot_diff_cdf(grouped)),
|
|
|
|
| 116 |
]
|
| 117 |
|
| 118 |
|
| 119 |
+
def model_comparer(samples: List[EvalGroup], rem_args: List[str]) -> List[Figure]:
|
| 120 |
+
parser = argparse.ArgumentParser()
|
| 121 |
+
parser.add_argument("--base", type=str)
|
| 122 |
+
parser.add_argument("--compare", type=str)
|
| 123 |
+
args = parser.parse_args(rem_args)
|
| 124 |
+
base_name_filt: Optional[str] = args.base
|
| 125 |
+
comp_name_filt: Optional[str] = args.compare
|
| 126 |
+
|
| 127 |
+
if base_name_filt is None or comp_name_filt is None:
|
| 128 |
+
warnings.warn(
|
| 129 |
+
"Skipping model comparison plot, need to specify base and compare"
|
| 130 |
+
)
|
| 131 |
+
return []
|
| 132 |
+
|
| 133 |
+
grouped, base_name, comp_name = get_scores(samples, base_name_filt, comp_name_filt)
|
| 134 |
+
return get_figures(grouped, base_name, comp_name)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
def plot_diff_cdf(grouped: Dict[str, List[Scores]]) -> alt.HConcatChart:
|
| 138 |
charts = []
|
| 139 |
for title, score_list in grouped.items():
|
src/rbeval/plot/score_cdf_altair.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from typing import List
|
| 2 |
|
| 3 |
from rbeval.plot.data import Eval, EvalGroup, Figure
|
|
@@ -9,25 +10,31 @@ import pandas as pd
|
|
| 9 |
from rbeval.plot.utils import CdfData, renormed
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
|
| 13 |
return [
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
),
|
| 18 |
-
Figure(
|
| 19 |
-
name="Corr-Incorr Gap Perf Curve",
|
| 20 |
-
chart=plot_with_config(CorrIncorrDiffConfig(), samples),
|
| 21 |
-
),
|
| 22 |
]
|
| 23 |
|
| 24 |
|
| 25 |
-
def plot_with_config(
|
| 26 |
cfg: "CdfPlotConfig",
|
| 27 |
samples: List[EvalGroup],
|
| 28 |
-
) ->
|
| 29 |
-
|
| 30 |
for renorm in [True, False]:
|
|
|
|
| 31 |
for group in samples:
|
| 32 |
dfs = []
|
| 33 |
for m in group.model_evals:
|
|
@@ -38,43 +45,52 @@ def plot_with_config(
|
|
| 38 |
"x": cdf.scores,
|
| 39 |
"y": cdf.cdf_p,
|
| 40 |
"label": m.model_name,
|
|
|
|
| 41 |
"renorm": renorm,
|
| 42 |
"fewshot": spec.fewshot,
|
| 43 |
}
|
| 44 |
)
|
| 45 |
dfs.append(df)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
selection = alt.selection_point(fields=["label"], bind="legend")
|
| 49 |
-
charts = []
|
| 50 |
-
for group, df in zip(samples, group_dfs):
|
| 51 |
-
chart = (
|
| 52 |
-
alt.Chart(df)
|
| 53 |
-
.mark_line()
|
| 54 |
-
.encode(
|
| 55 |
-
x=alt.X("x:Q", title=cfg.xlabel),
|
| 56 |
-
y=alt.Y("y:Q", title=cfg.ylabel),
|
| 57 |
-
color=alt.Color("label:N", legend=alt.Legend(symbolOpacity=1.0)),
|
| 58 |
-
opacity=alt.condition(
|
| 59 |
-
selection, alt.Opacity("fewshot:O"), alt.value(0.1)
|
| 60 |
-
),
|
| 61 |
-
)
|
| 62 |
-
.properties(title=cfg.title(group.name, renorm))
|
| 63 |
-
.resolve_legend(color="independent")
|
| 64 |
-
)
|
| 65 |
|
| 66 |
-
charts.append(chart)
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
class CdfPlotConfig(ABC):
|
| 75 |
plot_type: str
|
| 76 |
xlabel: str
|
| 77 |
ylabel: str
|
|
|
|
| 78 |
|
| 79 |
@abstractmethod
|
| 80 |
def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "CdfData":
|
|
@@ -92,6 +108,8 @@ class CdfPlotConfig(ABC):
|
|
| 92 |
|
| 93 |
|
| 94 |
class CorrectProbCdfPlot(CdfPlotConfig):
|
|
|
|
|
|
|
| 95 |
def __init__(self):
|
| 96 |
self.plot_type = "corr perf plot"
|
| 97 |
self.xlabel = "Correct answer probability"
|
|
@@ -112,6 +130,8 @@ class CorrectProbCdfPlot(CdfPlotConfig):
|
|
| 112 |
|
| 113 |
|
| 114 |
class CorrIncorrDiffConfig(CdfPlotConfig):
|
|
|
|
|
|
|
| 115 |
def __init__(self):
|
| 116 |
self.plot_type = "corr-max(incor) perf plot"
|
| 117 |
self.xlabel = "corr prob - max(incor prob)"
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
from typing import List
|
| 3 |
|
| 4 |
from rbeval.plot.data import Eval, EvalGroup, Figure
|
|
|
|
| 10 |
from rbeval.plot.utils import CdfData, renormed
|
| 11 |
|
| 12 |
|
| 13 |
+
@dataclass
|
| 14 |
+
class PlotData:
|
| 15 |
+
renorm: List[pd.DataFrame] = field(default_factory=list)
|
| 16 |
+
norenorm: List[pd.DataFrame] = field(default_factory=list)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def plot_cfgs():
|
| 20 |
+
return [CorrectProbCdfPlot(), CorrIncorrDiffConfig()]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
def score_cdf(samples: List[EvalGroup], args: List[str]) -> List[Figure]:
|
| 24 |
return [
|
| 25 |
+
a
|
| 26 |
+
for cfg in plot_cfgs()
|
| 27 |
+
for a in plot_with_data(cfg, get_plot_data(cfg, samples))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
]
|
| 29 |
|
| 30 |
|
| 31 |
+
def get_plot_data(
|
| 32 |
cfg: "CdfPlotConfig",
|
| 33 |
samples: List[EvalGroup],
|
| 34 |
+
) -> PlotData:
|
| 35 |
+
data = PlotData()
|
| 36 |
for renorm in [True, False]:
|
| 37 |
+
gfs = data.renorm if renorm else data.norenorm
|
| 38 |
for group in samples:
|
| 39 |
dfs = []
|
| 40 |
for m in group.model_evals:
|
|
|
|
| 45 |
"x": cdf.scores,
|
| 46 |
"y": cdf.cdf_p,
|
| 47 |
"label": m.model_name,
|
| 48 |
+
"group": group.name,
|
| 49 |
"renorm": renorm,
|
| 50 |
"fewshot": spec.fewshot,
|
| 51 |
}
|
| 52 |
)
|
| 53 |
dfs.append(df)
|
| 54 |
+
gfs.append(pd.concat(dfs))
|
| 55 |
+
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
|
|
|
| 57 |
|
| 58 |
+
def plot_with_data(
|
| 59 |
+
cfg: "CdfPlotConfig",
|
| 60 |
+
data: PlotData,
|
| 61 |
+
) -> List[Figure]:
|
| 62 |
+
figures = []
|
| 63 |
+
for renorm, group_dfs in zip([True, False], [data.renorm, data.norenorm]):
|
| 64 |
+
for df in group_dfs:
|
| 65 |
+
group_name = df["group"].iloc[0]
|
| 66 |
+
selection = alt.selection_point(fields=["label"], bind="legend")
|
| 67 |
+
chart = (
|
| 68 |
+
alt.Chart(df)
|
| 69 |
+
.mark_line()
|
| 70 |
+
.encode(
|
| 71 |
+
x=alt.X("x:Q", title=cfg.xlabel),
|
| 72 |
+
y=alt.Y("y:Q", title=cfg.ylabel),
|
| 73 |
+
color=alt.Color("label:N", legend=alt.Legend(symbolOpacity=1.0)),
|
| 74 |
+
opacity=alt.condition(
|
| 75 |
+
selection, alt.Opacity("fewshot:O"), alt.value(0.1)
|
| 76 |
+
),
|
| 77 |
+
)
|
| 78 |
+
.properties(title=cfg.title(group_name, renorm), width=800, height=400)
|
| 79 |
+
.resolve_legend(color="independent")
|
| 80 |
+
.resolve_axis(y="independent", x="independent")
|
| 81 |
+
.add_params(selection)
|
| 82 |
+
.interactive()
|
| 83 |
+
)
|
| 84 |
+
figures.append(Figure(name=f"{group_name} {cfg.name}", chart=chart))
|
| 85 |
+
|
| 86 |
+
return figures
|
| 87 |
|
| 88 |
|
| 89 |
class CdfPlotConfig(ABC):
|
| 90 |
plot_type: str
|
| 91 |
xlabel: str
|
| 92 |
ylabel: str
|
| 93 |
+
name: str = ""
|
| 94 |
|
| 95 |
@abstractmethod
|
| 96 |
def get_cdf(self, evals: List[Eval], prob_renorm: bool) -> "CdfData":
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
class CorrectProbCdfPlot(CdfPlotConfig):
|
| 111 |
+
name = "Correct Prob Perf Curve"
|
| 112 |
+
|
| 113 |
def __init__(self):
|
| 114 |
self.plot_type = "corr perf plot"
|
| 115 |
self.xlabel = "Correct answer probability"
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
class CorrIncorrDiffConfig(CdfPlotConfig):
|
| 133 |
+
name = "Corr-Incorr Gap Perf Curve"
|
| 134 |
+
|
| 135 |
def __init__(self):
|
| 136 |
self.plot_type = "corr-max(incor) perf plot"
|
| 137 |
self.xlabel = "corr prob - max(incor prob)"
|
src/rbeval/plot/utils.py
CHANGED
|
@@ -59,7 +59,7 @@ class CdfData:
|
|
| 59 |
|
| 60 |
@classmethod
|
| 61 |
def from_weights(
|
| 62 |
-
cls, weights: np.ndarray, scores: np.ndarray, max_p=
|
| 63 |
) -> "CdfData":
|
| 64 |
sort_perm = scores.argsort()
|
| 65 |
base_weights = weights[sort_perm]
|
|
|
|
| 59 |
|
| 60 |
@classmethod
|
| 61 |
def from_weights(
|
| 62 |
+
cls, weights: np.ndarray, scores: np.ndarray, max_p=600
|
| 63 |
) -> "CdfData":
|
| 64 |
sort_perm = scores.argsort()
|
| 65 |
base_weights = weights[sort_perm]
|