"""
tests/test_feedback.py
======================
Unit tests for the RL feedback subsystem:
- database/feedback.py (FeedbackStore)
- utils/reward.py (compute_reward, compute_combined_reward, reward_to_weight)
- utils/llm_reward.py (get_llm_reward — offline/no-key path only)
"""
from __future__ import annotations
import pytest
from database.feedback import FeedbackStore
from utils.reward import compute_combined_reward, compute_reward, reward_to_weight
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def store() -> FeedbackStore:
    """Provide a ``FeedbackStore`` opened on the ``":memory:"`` database.

    The fixture uses pytest's default (function) scope, so every test that
    requests it receives its own store instance.
    """
    in_memory_store = FeedbackStore(":memory:")
    return in_memory_store
# ---------------------------------------------------------------------------
# FeedbackStore — table creation & basic CRUD
# ---------------------------------------------------------------------------
class TestFeedbackStore:
    """Exercise ``FeedbackStore``: schema creation, saving, stats and counts."""

    def test_tables_created(self, store: FeedbackStore) -> None:
        """Both tables are present as soon as the store has been constructed."""
        cursor = store._conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table'"
        )
        # Rows are read by column name, so collect the "name" column as a set.
        table_names = set()
        for record in cursor.fetchall():
            table_names.add(record["name"])
        assert "feedback" in table_names
        assert "experience" in table_names

    def test_save_feedback_correct(self, store: FeedbackStore) -> None:
        """A confirmed prediction adds one feedback row and one experience row."""
        feedback_id = store.save_feedback(
            username="alice",
            text="I am so stressed",
            prediction=0.8,
            user_feedback=1,
            reward=1.0,
        )
        assert feedback_id >= 1

        # Feedback table: exactly the row saved above, with no LLM reward.
        saved_feedback = store.get_all_feedback()
        assert len(saved_feedback) == 1
        feedback_row = saved_feedback[0]
        assert feedback_row["username"] == "alice"
        assert feedback_row["user_feedback"] == 1
        assert feedback_row["reward"] == pytest.approx(1.0)
        assert feedback_row["llm_reward"] is None

        # Experience table: the prediction was confirmed, so the expected
        # label is round(0.8) = 1.
        experience = store.get_experience_for_training(min_samples=1)
        assert len(experience) == 1
        sample = experience[0]
        assert sample["label"] == 1
        assert sample["reward"] == pytest.approx(1.0)

    def test_save_feedback_wrong(self, store: FeedbackStore) -> None:
        """A rejected prediction is stored with the opposite label."""
        store.save_feedback(
            username="bob",
            text="Everything is fine",
            prediction=0.8,  # model predicted class 1 (stressed)
            user_feedback=0,  # user rejected it, so the expected label is 0
            reward=-1.0,
        )
        experience = store.get_experience_for_training(min_samples=1)
        assert len(experience) == 1
        sample = experience[0]
        assert sample["label"] == 0

    def test_save_feedback_with_llm_reward(self, store: FeedbackStore) -> None:
        """An ``llm_reward`` passed to ``save_feedback`` is read back unchanged."""
        store.save_feedback(
            username="carol",
            text="text",
            prediction=0.6,
            user_feedback=1,
            reward=1.0,
            llm_reward=1,
        )
        first_row = store.get_all_feedback()[0]
        assert first_row["llm_reward"] == 1

    def test_get_user_stats_no_data(self, store: FeedbackStore) -> None:
        """A user with no feedback has a total of 0 and a mean reward of 0.0."""
        empty_stats = store.get_user_stats("nobody")
        assert empty_stats["total"] == 0
        assert empty_stats["mean_reward"] == pytest.approx(0.0)

    def test_get_user_stats(self, store: FeedbackStore) -> None:
        """Stats aggregate two confirmations and one rejection for one user."""
        store.save_feedback("alice", "t1", 0.8, 1, 1.0)
        store.save_feedback("alice", "t2", 0.8, 0, -1.0)
        store.save_feedback("alice", "t3", 0.7, 1, 1.0)

        alice_stats = store.get_user_stats("alice")
        assert alice_stats["total"] == 3
        assert alice_stats["n_correct"] == 2
        assert alice_stats["n_wrong"] == 1
        # Rewards saved are +1, -1, +1 -> mean 1/3; two of three confirmed.
        assert alice_stats["mean_reward"] == pytest.approx(1 / 3, abs=0.01)
        assert alice_stats["accuracy_rate"] == pytest.approx(2 / 3, abs=0.01)

    def test_min_samples_gate(self, store: FeedbackStore) -> None:
        """``get_experience_for_training`` yields ``[]`` below ``min_samples``."""
        store.save_feedback("alice", "text", 0.7, 1, 1.0)

        below_threshold = store.get_experience_for_training(min_samples=5)
        assert below_threshold == []

        at_threshold = store.get_experience_for_training(min_samples=1)
        assert len(at_threshold) == 1

    def test_feedback_count(self, store: FeedbackStore) -> None:
        """The count is global without an argument and per-user with one."""
        assert store.get_feedback_count() == 0

        # (username, user_feedback, reward) for each saved entry.
        entries = (
            ("u1", 1, 1.0),
            ("u1", 0, -1.0),
            ("u2", 1, 1.0),
        )
        for user, verdict, reward in entries:
            store.save_feedback(user, "t", 0.5, verdict, reward)

        assert store.get_feedback_count() == 3
        assert store.get_feedback_count("u1") == 2
        assert store.get_feedback_count("u2") == 1

    def test_multiple_users_isolated(self, store: FeedbackStore) -> None:
        """Each user's stats reflect only that user's own feedback."""
        store.save_feedback("alice", "text", 0.9, 1, 1.0)
        store.save_feedback("bob", "text", 0.2, 0, -1.0)

        alice_stats = store.get_user_stats("alice")
        bob_stats = store.get_user_stats("bob")

        assert alice_stats["total"] == 1
        assert alice_stats["n_correct"] == 1
        assert bob_stats["total"] == 1
        assert bob_stats["n_wrong"] == 1

    def test_close(self, store: FeedbackStore) -> None:
        """Calling ``close()`` on an open store does not raise."""
        store.close()
# ---------------------------------------------------------------------------
# Reward functions
# ---------------------------------------------------------------------------
class TestComputeReward:
    """Check the reward ``compute_reward`` assigns to each user-feedback value."""

    # NOTE(review): the two ``*_low_prediction`` tests make exactly the same
    # calls as the first two tests; no prediction value is passed anywhere.
    # Confirm whether ``compute_reward`` accepts a prediction argument that
    # these tests were meant to exercise.

    def test_correct_gives_positive(self) -> None:
        """Feedback value 1 yields a reward of +1.0."""
        reward = compute_reward(1)
        assert reward == pytest.approx(1.0)

    def test_wrong_gives_negative(self) -> None:
        """Feedback value 0 yields a reward of -1.0."""
        reward = compute_reward(0)
        assert reward == pytest.approx(-1.0)

    def test_correct_low_prediction(self) -> None:
        """Feedback value 1 yields a reward of +1.0."""
        reward = compute_reward(1)
        assert reward == pytest.approx(1.0)

    def test_wrong_low_prediction(self) -> None:
        """Feedback value 0 yields a reward of -1.0."""
        reward = compute_reward(0)
        assert reward == pytest.approx(-1.0)
class TestComputeCombinedReward:
    """Check ``compute_combined_reward`` for each user/LLM signal combination."""

    def test_no_llm_passes_through(self) -> None:
        """With ``None`` as the LLM reward, the result is +1.0 or -1.0."""
        confirmed = compute_combined_reward(1, None)
        rejected = compute_combined_reward(0, None)
        assert confirmed == pytest.approx(1.0)
        assert rejected == pytest.approx(-1.0)

    def test_llm_agree_positive(self) -> None:
        """User (+1) and LLM (+1) agree, so the expected average is +1."""
        combined = compute_combined_reward(1, 1)
        assert combined == pytest.approx(1.0)

    def test_llm_agree_negative(self) -> None:
        """User (-1) and LLM (-1) agree, so the expected average is -1."""
        combined = compute_combined_reward(0, -1)
        assert combined == pytest.approx(-1.0)

    def test_llm_disagree_averages(self) -> None:
        """User says correct (+1), LLM says wrong (-1): expected average 0.0."""
        assert compute_combined_reward(1, -1) == pytest.approx(0.0)

    def test_llm_partial_agreement(self) -> None:
        """User says wrong (-1), LLM says correct (+1): expected result 0.0.

        Despite the method name, this is the mirror-image disagreement case.
        """
        assert compute_combined_reward(0, 1) == pytest.approx(0.0)
class TestRewardToWeight:
    """Check the weight ``reward_to_weight`` returns for given rewards."""

    def test_nonzero_reward_gives_1_5(self) -> None:
        """Rewards of 1.0, -1.0 and 0.5 each map to a weight of 1.5."""
        for reward in (1.0, -1.0, 0.5):
            assert reward_to_weight(reward) == pytest.approx(1.5)

    def test_zero_reward_gives_1_0(self) -> None:
        """A reward of 0.0 maps to a weight of 1.0."""
        weight = reward_to_weight(0.0)
        assert weight == pytest.approx(1.0)
# ---------------------------------------------------------------------------
# LLM reward (no-key / offline path)
# ---------------------------------------------------------------------------
class TestGetLlmRewardOffline:
    """Check that ``get_llm_reward`` returns ``None`` when API keys are unset.

    Each test removes the relevant environment variable(s) first, so no real
    API key is needed to run them.
    """

    def test_returns_none_without_keys(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """With both keys unset, ``provider="auto"`` returns ``None``."""
        for key_name in ("OPENAI_API_KEY", "GEMINI_API_KEY"):
            monkeypatch.delenv(key_name, raising=False)

        # Imported only after the environment has been cleared.
        from utils.llm_reward import get_llm_reward

        assert get_llm_reward("I am stressed", 0.8, provider="auto") is None

    def test_openai_returns_none_without_key(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """With ``OPENAI_API_KEY`` unset, ``provider="openai"`` returns ``None``."""
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)

        from utils.llm_reward import get_llm_reward

        assert get_llm_reward("text", 0.5, provider="openai") is None

    def test_gemini_returns_none_without_key(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """With ``GEMINI_API_KEY`` unset, ``provider="gemini"`` returns ``None``."""
        monkeypatch.delenv("GEMINI_API_KEY", raising=False)

        from utils.llm_reward import get_llm_reward

        assert get_llm_reward("text", 0.5, provider="gemini") is None
|