File size: 4,112 Bytes
656f91e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Capture user feedback on answers — the self-improving flywheel.

The MVP's core idea: a small shop owner can't audit SQL, but they CAN say
"that's right" or "no, that's wrong — it should be X". Because the app already
shows the exact query it ran, each thumbs-down + correction is a high-signal
*candidate training example*: the question, the SQL the model actually produced,
and what the human says is correct. These accumulate into exactly the data the
next training round needs β€” notably the multi-hop-join / error-recovery cases
the model is weakest on (see docs/guides/training_playbook.md).

Dep-light, stdlib only: one JSON object per line (JSONL), append-only, so it is
trivially greppable and feeds straight into scripts/generate_sft_data.py-style
curation later. No model or heavy deps imported here.
"""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import Any

# Where feedback is written when no explicit path is supplied. This file is a
# runtime artifact rather than source; records are only ever appended to it,
# one line per feedback click.
DEFAULT_FEEDBACK_PATH = Path("data") / "feedback" / "feedback.jsonl"

# The closed set of verdicts record_feedback() accepts.
_VERDICTS = ("up", "down")


def record_feedback(
    *,
    question: str,
    dataset: str,
    shown_sql: str,
    result: str,
    verdict: str,
    correction: str = "",
    path: str | Path | None = None,
) -> dict[str, Any]:
    """Persist one piece of answer feedback as a JSONL line and return it.

    The record is serialized with ``json.dumps(..., ensure_ascii=False)`` and
    appended as a single line; missing parent directories are created first.

    Args:
        question: The natural-language question the user asked.
        dataset: The dataset/DB the question was asked against.
        shown_sql: The SQL the model actually ran (what the user can audit).
        result: The answer/result the user is reacting to.
        verdict: Either "up" (correct) or "down" (wrong).
        correction: Free-text "what it should be". Surrounding whitespace is
            stripped and ``None`` is treated as empty.
        path: Optional JSONL location; ``DEFAULT_FEEDBACK_PATH`` when omitted.

    Returns:
        The stored record: ``ts`` (``time.time()``), the inputs above, the
        cleaned ``correction`` and an ``is_training_candidate`` flag.

    Raises:
        ValueError: If ``verdict`` is neither "up" nor "down".
    """
    if verdict not in _VERDICTS:
        raise ValueError(f"verdict must be one of {_VERDICTS}, got {verdict!r}")

    cleaned = (correction or "").strip()
    # Only a down-vote that states what the answer *should* have been is a
    # labelled "model said X, truth is Y" pair worth training on.
    candidate = verdict == "down" and bool(cleaned)

    record: dict[str, Any] = dict(
        ts=time.time(),
        question=question,
        dataset=dataset,
        shown_sql=shown_sql,
        result=result,
        verdict=verdict,
        correction=cleaned,
        is_training_candidate=candidate,
    )

    target = DEFAULT_FEEDBACK_PATH if path is None else Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("a", encoding="utf-8") as sink:
        sink.write(f"{json.dumps(record, ensure_ascii=False)}\n")
    return record


def load_feedback(path: str | Path | None = None) -> list[dict[str, Any]]:
    """Read all feedback records (for inspection / the demo's flywheel counter).

    Args:
        path: Optional JSONL location; ``DEFAULT_FEEDBACK_PATH`` when omitted.

    Returns:
        The decoded JSON objects, in file order. Returns an empty list if the
        file does not exist. Blank lines, lines that are not valid JSON, and
        JSON values that are not objects are skipped, so one bad line cannot
        break feedback_summary() inside a UI click handler.
    """
    src = Path(path) if path is not None else DEFAULT_FEEDBACK_PATH
    records: list[dict[str, Any]] = []
    # errors="replace": a torn multi-byte character left by an interrupted
    # write must not abort the whole read with UnicodeDecodeError; the damaged
    # line then normally fails JSON decoding and is skipped below.
    try:
        fh = src.open(encoding="utf-8", errors="replace")
    except FileNotFoundError:
        return records  # nothing recorded yet
    with fh:
        # Iterate physical lines instead of using str.splitlines(): records are
        # written with ensure_ascii=False, which leaves U+2028/U+2029/U+0085
        # unescaped, and splitlines() would cut such a record into two
        # undecodable halves. JSON escapes "\n" and "\r", so a line break is
        # always a record boundary.
        for raw in fh:
            line = raw.strip()
            if not line:
                continue
            # Skip a malformed/truncated line (e.g. an interrupted write).
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Callers use dict access (r.get); ignore stray non-object values.
            if isinstance(rec, dict):
                records.append(rec)
    return records


def feedback_summary(path: str | Path | None = None) -> dict[str, int]:
    """Tally feedback records for the "flywheel" status line.

    Returns a dict with ``total`` (records returned by load_feedback),
    ``up`` / ``down`` (records whose ``verdict`` is "up" / "down") and
    ``training_candidates`` (records whose ``is_training_candidate`` field is
    truthy, as set by record_feedback()).
    """
    entries = load_feedback(path)
    ups = downs = candidates = 0
    for entry in entries:
        vote = entry.get("verdict")
        if vote == "up":
            ups += 1
        elif vote == "down":
            downs += 1
        if entry.get("is_training_candidate"):
            candidates += 1
    return {
        "total": len(entries),
        "up": ups,
        "down": downs,
        "training_candidates": candidates,
    }