File size: 2,133 Bytes
6bef416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from __future__ import annotations

import numpy as np
import pandas as pd
import streamlit as st

from src.charts import bubble_quality, heatmap
from src.ui import callout, section_title


class QualityMapViewMixin:
    """Page mixin rendering quality views by domain, difficulty, and risk slice.

    The host class is expected to supply ``self.ctx`` (read here for
    ``filtered_eval``, ``risk_slices`` and ``demand_coverage``) and
    ``self._plot``; neither is defined in this mixin.
    """

    def _page_quality_map(self) -> None:
        """Render the "Quality Map" page.

        Draws two side-by-side heatmaps built from the domain/difficulty
        aggregate, then the risk-slice bubble chart, and finally an
        informational callout when ``ctx.demand_coverage`` is non-empty.
        """
        section_title("Quality Map", "Where quality is strong, unstable, or worth deeper review.")
        table = self._domain_difficulty_table(self.ctx.filtered_eval)

        left_col, right_col = st.columns(2, gap="large")
        # One row per heatmap: (container, metric column, chart title, plot key).
        panels = (
            (left_col, "correct_rate", "Correctness · domain × difficulty", "quality_correctness"),
            (right_col, "hallucination_rate", "Hallucination · domain × difficulty", "quality_hallucination"),
        )
        for column, metric, title, plot_key in panels:
            with column:
                figure = heatmap(table, "difficulty", "domain", metric, title)
                self._plot(figure, plot_key)

        risk_figure = bubble_quality(self.ctx.risk_slices, "Risk slice map")
        self._plot(risk_figure, "quality_risk_bubble")

        # The explanatory note is only shown when coverage data is present.
        if self.ctx.demand_coverage.empty:
            return
        callout(
            "info",
            "Coverage interpretation",
            "Positive demand-minus-corpus means the domain receives more evaluation demand than its corpus document share.",
        )

    @staticmethod
    def _domain_difficulty_table(df: pd.DataFrame) -> pd.DataFrame:
        """Aggregate evaluation rows per (domain, difficulty) pair.

        Args:
            df: Frame that should contain ``domain`` and ``difficulty``
                columns; ``is_correct``, ``hallucination_flag`` and
                ``recall_at_10`` are averaged when present.

        Returns:
            One row per group (missing group keys are kept as their own
            group) with the columns ``domain``, ``difficulty``, ``n`` (row
            count), ``correct_rate``, ``hallucination_rate`` and
            ``recall_at_10``. A metric whose source column is absent is
            filled with NaN. If ``df`` is empty or lacks a grouping column,
            an empty frame with no columns is returned.
        """
        group_keys = ["domain", "difficulty"]
        if df.empty:
            return pd.DataFrame()
        if set(group_keys).difference(df.columns):
            return pd.DataFrame()

        # Output metric name -> source column averaged to produce it.
        metric_sources = (
            ("correct_rate", "is_correct"),
            ("hallucination_rate", "hallucination_flag"),
            ("recall_at_10", "recall_at_10"),
        )
        named_aggs: dict[str, tuple[str, str]] = {"n": ("domain", "size")}
        named_aggs.update(
            {metric: (source, "mean") for metric, source in metric_sources if source in df.columns}
        )

        grouped = df.groupby(group_keys, dropna=False)
        table = grouped.agg(**named_aggs).reset_index()

        # Guarantee every metric column exists so consumers can reference it.
        for metric, _source in metric_sources:
            if metric not in table.columns:
                table[metric] = np.nan
        return table