File size: 8,042 Bytes
c79d967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e181667
 
 
c79d967
 
 
 
 
 
 
 
 
e181667
 
 
c79d967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e181667
c79d967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e181667
c79d967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e181667
c79d967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e181667
c79d967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e181667
c79d967
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""
Telemetry: in-memory counters + HF Dataset persistence.

In-memory  — rolling deque(500), per-metric pass counters, avg latency.
             Resets on restart. Powers /metrics (live session stats).

Persistent — events flushed as JSONL shards to TELEMETRY_REPO HF dataset
             every FLUSH_EVERY queries. Powers /report (accumulated history).
             Falls back to in-memory if TELEMETRY_REPO is unset or write fails.

TELEMETRY_REPO is inferred from HF Spaces env vars:
  SPACE_AUTHOR_NAME + SPACE_REPO_NAME → {author}/{repo}-telemetry
  Override with explicit TELEMETRY_REPO env var.
"""

import json
import logging
import os
import threading
from collections import defaultdict, deque
from datetime import UTC, datetime
from typing import Any

from grader import GradeReport

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Capacity of the rolling in-memory event window (`_events`).
BUFFER_SIZE = 500
# Number of queued events at which `record` starts a background flush.
FLUSH_EVERY = 20

# Guards `_events`, `_unflushed` and `_counters`; held by `record`,
# `live_stats` and `_flush` while they touch that shared state.
_lock = threading.Lock()
# Most recent events; the deque silently discards the oldest past BUFFER_SIZE.
_events: deque[dict[str, Any]] = deque(maxlen=BUFFER_SIZE)
# Events recorded but not yet uploaded; drained by `_flush`.
_unflushed: list[dict[str, Any]] = []
# Session counters keyed by name: "total", "overall_pass", "<metric>_total",
# "<metric>_pass", "lat_<stage>_sum" and "lat_<stage>_n". Missing keys read as 0.0.
_counters: dict[str, float] = defaultdict(float)

# Target dataset repo. An explicit TELEMETRY_REPO env var wins; otherwise it is
# derived as "<author>/<repo>-telemetry" from the SPACE_* variables, and it is
# left empty (persistence disabled) when SPACE_AUTHOR_NAME is not set.
_space_author = os.environ.get("SPACE_AUTHOR_NAME", "")
_space_repo = os.environ.get("SPACE_REPO_NAME", "ai-response-validator")
TELEMETRY_REPO = os.environ.get(
    "TELEMETRY_REPO",
    f"{_space_author}/{_space_repo}-telemetry" if _space_author else "",
)

# Metric names reported by `live_stats` and `persistent_report`, in output order.
_METRICS = ["pii_leakage", "token_budget", "answer_relevancy", "faithfulness", "chain_terminology"]


def record(
    client: str,
    domain: str,
    query_len: int,
    latency_ms: dict[str, float],
    report: GradeReport,
    docs_retrieved: int,
    min_retrieval_score: float,
) -> None:
    """Record one query event and update the session counters.

    The event is appended to the rolling in-memory window and folded into the
    aggregate counters while holding the module lock, so concurrent callers
    do not interleave their updates. When a telemetry repo is configured the
    event is also queued for upload, and once FLUSH_EVERY events are pending
    a daemon thread is started to push them to the HF dataset.

    Args:
        client: Identifier of the calling client; stored verbatim.
        domain: Domain label for the query; stored verbatim.
        query_len: Length of the query, as measured by the caller.
        latency_ms: Per-stage latency in milliseconds, keyed by stage name.
        report: Grading outcome; ``report.results`` supplies per-metric
            ``metric``/``score``/``passed`` and ``report.overall`` the verdict.
        docs_retrieved: Number of documents retrieved for the query.
        min_retrieval_score: Lowest retrieval score among those documents.
    """
    results = report.results
    event = {
        "ts": datetime.now(UTC).isoformat(),
        "client": client,
        "domain": domain,
        "query_len": query_len,
        "latency_ms": {k: round(v) for k, v in latency_ms.items()},
        "metrics": {r.metric: round(r.score, 4) for r in results},
        "metric_passed": {r.metric: r.passed for r in results},
        "overall_pass": report.overall,
        "docs_retrieved": docs_retrieved,
        "min_retrieval_score": round(min_retrieval_score, 4),
    }
    # Without a target repo nothing ever drains the upload queue, so queueing
    # events there would grow the list for the lifetime of the process.
    persist = bool(TELEMETRY_REPO)

    with _lock:
        _events.append(event)
        if persist:
            _unflushed.append(event)
        _counters["total"] += 1
        if report.overall:
            _counters["overall_pass"] += 1
        for r in results:
            _counters[f"{r.metric}_total"] += 1
            if r.passed:
                _counters[f"{r.metric}_pass"] += 1
        for stage, ms in latency_ms.items():
            _counters[f"lat_{stage}_sum"] += ms
            _counters[f"lat_{stage}_n"] += 1
        should_flush = persist and len(_unflushed) >= FLUSH_EVERY

    # Start the upload outside the lock so network I/O never blocks recorders.
    if should_flush:
        threading.Thread(target=_flush, daemon=True).start()


def live_stats() -> dict[str, Any]:
    """Summarise the counters gathered since process start (/metrics endpoint).

    Returns a short placeholder dict when no query has been recorded yet;
    otherwise the overall pass rate, per-metric pass statistics, average
    latency for the retrieve/generate/grade stages, the number of events in
    the rolling buffer and the configured telemetry repo (None when unset).
    The whole snapshot is taken under the module lock.
    """
    with _lock:
        query_count = int(_counters.get("total", 0))
        if not query_count:
            return {"total_queries": 0, "message": "No queries recorded this session."}

        per_metric = {}
        for name in _METRICS:
            graded = int(_counters.get(f"{name}_total", 0))
            passes = int(_counters.get(f"{name}_pass", 0))
            # A metric that was never graded has no meaningful rate.
            rate = round(passes / graded, 3) if graded else None
            per_metric[name] = {"pass_rate": rate, "pass_count": passes, "total": graded}

        mean_latency = {}
        for stage in ("retrieve", "generate", "grade"):
            samples = _counters.get(f"lat_{stage}_n", 0)
            if not samples:
                continue
            mean_latency[stage] = round(_counters[f"lat_{stage}_sum"] / samples)

        overall_rate = round(_counters.get("overall_pass", 0) / query_count, 3)
        return {
            "source": "in_memory",
            "total_queries": query_count,
            "overall_pass_rate": overall_rate,
            "metrics": per_metric,
            "avg_latency_ms": mean_latency,
            "events_in_buffer": len(_events),
            "telemetry_repo": TELEMETRY_REPO or None,
        }


def _read_shard_events(shard_file: str) -> list[dict[str, Any]]:
    """Parse one downloaded JSONL shard, skipping blank or unusable lines."""
    parsed: list[dict[str, Any]] = []
    with open(shard_file, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as exc:
                # One damaged line should not discard the whole history.
                log.warning("Skipping corrupt telemetry line %s:%d (%s)", shard_file, line_no, exc)
                continue
            if isinstance(obj, dict):
                parsed.append(obj)
            else:
                log.warning("Skipping non-object telemetry line %s:%d", shard_file, line_no)
    return parsed


def _summarize_events(events: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate a non-empty list of event dicts into report statistics."""
    total = len(events)
    overall_pass = sum(1 for e in events if e.get("overall_pass"))

    metric_stats = {}
    for m in _METRICS:
        # Only events that were actually graded on this metric count towards
        # its pass rate, matching how live_stats computes the same figure.
        outcomes = [
            e["metric_passed"][m] for e in events if m in (e.get("metric_passed") or {})
        ]
        passed = sum(1 for ok in outcomes if ok)
        scores = [e["metrics"][m] for e in events if m in (e.get("metrics") or {})]
        metric_stats[m] = {
            "pass_rate": round(passed / len(outcomes), 3) if outcomes else None,
            "avg_score": round(sum(scores) / len(scores), 3) if scores else None,
            "pass_count": passed,
            "total": len(outcomes),
        }

    client_breakdown: dict[str, dict[str, int]] = defaultdict(lambda: {"total": 0, "pass": 0})
    for e in events:
        c = e.get("client", "unknown")
        client_breakdown[c]["total"] += 1
        if e.get("overall_pass"):
            client_breakdown[c]["pass"] += 1

    # ISO-8601 timestamps written by this module sort chronologically as text.
    stamps = [e["ts"] for e in events if "ts" in e]

    return {
        "total_queries": total,
        "overall_pass_rate": round(overall_pass / total, 3),
        "first_event": min(stamps) if stamps else None,
        "last_event": max(stamps) if stamps else None,
        "metrics": metric_stats,
        "by_client": {
            c: {"total": v["total"], "pass_rate": round(v["pass"] / v["total"], 3)}
            for c, v in client_breakdown.items()
        },
    }


def persistent_report() -> dict[str, Any]:
    """Aggregate all persisted events from the HF dataset (/report endpoint).

    Lists the ``events/*.jsonl`` shards in TELEMETRY_REPO, downloads and
    parses each one, and returns overall, per-metric and per-client
    statistics together with the first/last event timestamps.

    Returns:
        A dict whose ``"source"`` key says where the numbers came from:
        ``"hf_dataset"`` for persisted data, ``"in_memory"`` when no
        telemetry repo is configured, and ``"in_memory_fallback"`` when
        reading the dataset failed and live session stats are returned
        instead. This function does not raise; any error triggers the
        fallback.
    """
    if not TELEMETRY_REPO:
        log.info("TELEMETRY_REPO not set — report from in-memory only")
        return {"source": "in_memory", **live_stats()}

    try:
        from huggingface_hub import HfApi
        hf_token = os.environ.get("HF_TOKEN")
        api = HfApi(token=hf_token)

        files = api.list_repo_files(TELEMETRY_REPO, repo_type="dataset")
        shard_paths = [f for f in files if f.startswith("events/") and f.endswith(".jsonl")]
        if not shard_paths:
            return {"source": "hf_dataset", "repo": TELEMETRY_REPO,
                    "message": "No shards yet — data accumulates after first flush."}

        events: list[dict[str, Any]] = []
        for path in shard_paths:
            local_file = api.hf_hub_download(
                TELEMETRY_REPO, path, repo_type="dataset", token=hf_token,
            )
            events.extend(_read_shard_events(local_file))

        if not events:
            return {"source": "hf_dataset", "repo": TELEMETRY_REPO, "total_events": 0}

        return {
            "source": "hf_dataset",
            "repo": TELEMETRY_REPO,
            **_summarize_events(events),
            "shards_read": len(shard_paths),
        }

    except Exception as e:
        log.warning("HF Dataset report failed (%s) — falling back to in-memory", e)
        # live_stats() carries its own "source" key; the fallback label must
        # come last so it is not overwritten.
        return {**live_stats(), "source": "in_memory_fallback"}


def _flush() -> None:
    """Upload the queued events to the HF dataset as one JSONL shard.

    Intended to run in a background thread. The pending queue is drained
    under the module lock and the upload happens outside it, so recorders
    are never blocked on network I/O. If HF_TOKEN is missing the drained
    batch is dropped with a warning. If anything else fails, the batch is
    put back at the front of the queue so a later flush retries it in the
    original chronological order.
    """
    with _lock:
        if not _unflushed:
            return
        batch = list(_unflushed)
        _unflushed.clear()

    try:
        from huggingface_hub import HfApi
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            log.warning("HF_TOKEN not set — telemetry flush skipped")
            return

        api = HfApi(token=hf_token)
        try:
            api.create_repo(TELEMETRY_REPO, repo_type="dataset", exist_ok=True, private=False)
        except Exception as exc:
            # Best effort: the repo may already exist or creation may be
            # forbidden for this token; the upload below is the real test.
            log.debug("create_repo for %s failed (%s); attempting upload anyway",
                      TELEMETRY_REPO, exc)

        # Microsecond-resolution timestamp keeps shard names unique per flush.
        ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S%f")
        content = "\n".join(json.dumps(e) for e in batch).encode()
        api.upload_file(
            path_or_fileobj=content,
            path_in_repo=f"events/shard_{ts}.jsonl",
            repo_id=TELEMETRY_REPO,
            repo_type="dataset",
        )
        log.info("Flushed %d telemetry events to %s", len(batch), TELEMETRY_REPO)

    except Exception as e:
        log.warning("Telemetry flush failed: %s — events returned to buffer", e)
        with _lock:
            # Re-insert ahead of anything recorded during the failed attempt.
            _unflushed[:0] = batch