# PGC-AI-Chatbot / tests/test_source_attribution.py
# Jacooo's picture
# Deploy from GitHub: 71cec45
# 70752d0 verified
# -*- coding: utf-8 -*-
"""
Three-Tier Source Attribution Test Script
==========================================
Tests that AI responses correctly attribute their source using the right
trailing emoji indicators:
📚 Tier 1 — Verified plant database (plants_database.json)
📖 Tier 2 — Verified RAG document (knowledge_chunks, similarity >=85%)
⚠️ Tier 3 — AI-generated estimate (Cerebras LLM fallback)
Usage:
python tests/test_source_attribution.py
"""
import asyncio
import sys
from pathlib import Path
# Force UTF-8 output so emojis render correctly on all terminals.
# reconfigure() is only provided by io.TextIOWrapper; when sys.stdout has been
# replaced by another file-like object, skip the call instead of crashing on
# import.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

# Add the directory above this file's folder to sys.path so imports work
sys.path.insert(0, str(Path(__file__).parent.parent))

# Fixed sensor readings passed as `sensors` with every test query.
MOCK_SENSORS = {"temp": 28.5, "rh": 70.0, "light": 15000}
# ─────────────────────────────────────────────────────────────────────────────
# Test Cases
# ─────────────────────────────────────────────────────────────────────────────
TEST_CASES = [
# --- Tier 1: Deterministic DB (plant in plants_database.json) ---
{
"id": "T1-A",
"description": "Known plant in DB - Indonesian query (tomat)",
"query": "berapa suhu ideal tomat fase vegetatif?",
"expected_emoji": "πŸ“š",
},
{
"id": "T1-B",
"description": "Known plant in DB - English query (watermelon)",
"query": "what are the parameters for growing watermelon seedling?",
"expected_emoji": "πŸ“š",
},
{
"id": "T1-C",
"description": "Known plant in DB - Indonesian (semangka)",
"query": "apa parameter pertumbuhan semangka?",
"expected_emoji": "πŸ“š",
},
{
"id": "T1-D",
"description": "Known plant in DB - lettuce germination",
"query": "suhu dan kelembaban untuk perkecambahan selada?",
"expected_emoji": "πŸ“š",
},
{
"id": "T1-E",
"description": "Chamber status - real-time sensor data",
"query": "berapa suhu chamber sekarang?",
"expected_emoji": "πŸ“š",
},
# --- Tier 3: LLM Fallback (plant NOT in DB) ---
{
"id": "T3-A",
"description": "Unknown plant - LLM fallback (durian)",
"query": "apa parameter pertumbuhan durian?",
"expected_emoji": "⚠️",
},
{
"id": "T3-B",
"description": "Unknown plant - LLM fallback (strawberry)",
"query": "what temperature does strawberry need to grow?",
"expected_emoji": "⚠️",
},
# --- General / Technical (always AI Generated) ---
{
"id": "T3-C",
"description": "General plant question (no specific plant)",
"query": "bagaimana cara mempercepat pertumbuhan tanaman secara umum?",
"expected_emoji": "⚠️",
},
{
"id": "T3-D",
"description": "Technical IoT question",
"query": "how does a DHT22 humidity sensor work?",
"expected_emoji": "⚠️",
},
{
"id": "T3-E",
"description": "General knowledge question",
"query": "what is photosynthesis?",
"expected_emoji": "⚠️",
},
]
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def detect_source_emoji(response_text: str) -> str | None:
"""
Detect the trailing source attribution emoji in the response.
Checks the last 8 lines for Source attribution lines.
"""
lines = response_text.strip().split("\n")
for line in reversed(lines[-8:]):
if "πŸ“š" in line and "Source" in line:
return "πŸ“š"
if "πŸ“–" in line and "Source" in line:
return "πŸ“–"
if "⚠️" in line and "Source" in line:
return "⚠️"
# Fallback: scan full response for opener indicators
if "πŸ“š According" in response_text or "πŸ“š Source" in response_text:
return "πŸ“š"
if "πŸ“– Source" in response_text:
return "πŸ“–"
if "⚠️ Note" in response_text or "⚠️ Source" in response_text:
return "⚠️"
return None
def truncate(text: str, n: int = 400) -> str:
    """Cap *text* at *n* characters, appending a marker when it was cut.

    Text of length *n* or less is returned unchanged.
    """
    if len(text) <= n:
        return text
    head = text[:n]
    return f"{head}\n ...[truncated]"
# ─────────────────────────────────────────────────────────────────────────────
# Main Test Runner
# ─────────────────────────────────────────────────────────────────────────────
async def run_tests():
    """Run every case in TEST_CASES against the AI engine and print a report.

    For each case the query is sent to generate_context_aware_response()
    together with MOCK_SENSORS, the attribution emoji is extracted from the
    "response" field with detect_source_emoji(), and the result is compared
    with the case's "expected_emoji". A summary table, the (truncated)
    responses and a pass/fail total are printed to stdout.

    An exception raised while handling a case is reported and counted as a
    failure; the remaining cases still run.

    Returns:
        The number of failed cases (0 when every case passed).
    """
    # Imported inside the function, after the module-level sys.path setup.
    from app.ai_engine import generate_context_aware_response

    width = 72
    rule = "=" * width
    results = []

    print("\n" + rule)
    print(" THREE-TIER SOURCE ATTRIBUTION TEST SUITE")
    print(rule)
    print(" 📚 Tier 1: Verified DB | 📖 Tier 2: Verified Doc (>=85%) | ⚠️ Tier 3: AI Generated")
    print(rule)
    print(f"{'ID':<8} {'EXPECTED':<12} {'GOT':<12} {'STATUS':<8} DESCRIPTION")
    print("-" * width)

    for case in TEST_CASES:
        try:
            result = await generate_context_aware_response(
                query=case["query"],
                sensors=MOCK_SENSORS,
            )
            response_text = result.get("response", "")
            detected = detect_source_emoji(response_text)
        except Exception as e:
            # Runner boundary: one broken case must not abort the whole suite.
            print(f"{case['id']:<8} {'?':<12} {'ERROR':<12} ❌ FAIL {case['description']}")
            print(f" Exception: {e}")
            results.append({
                "case": case,
                "response": f"ERROR: {e}",
                "detected": None,
                "data_source": "error",
                "query_type": "error",
                "passed": False,
            })
            continue

        ok = detected == case["expected_emoji"]
        status = "✅ PASS" if ok else "❌ FAIL"
        print(f"{case['id']:<8} {case['expected_emoji']:<12} {str(detected):<12} {status:<8} {case['description']}")
        results.append({
            "case": case,
            "response": response_text,
            "detected": detected,
            "data_source": result.get("data_source", "unknown"),
            "query_type": result.get("query_type", "unknown"),
            "passed": ok,
        })

    # Detailed response printout
    print("\n" + rule)
    print(" DETAILED RESPONSES")
    print(rule)
    for r in results:
        case = r["case"]
        icon = "✅" if r["passed"] else "❌"
        print(f"\n{icon} [{case['id']}] {case['description']}")
        print(f" Query : {case['query']}")
        print(f" Query Type : {r['query_type']}")
        print(f" Data Source: {r['data_source']}")
        print(f" Detected : {r['detected']} (Expected: {case['expected_emoji']})")
        print(" Response :")
        for line in truncate(r["response"], 500).split("\n"):
            print(f" {line}")

    # Summary
    total = len(TEST_CASES)
    passed = sum(1 for r in results if r["passed"])
    failed = len(results) - passed
    print("\n" + rule)
    print(f" RESULTS: {passed}/{total} passed | {failed} failed")
    if failed == 0:
        print(" ✅ All source attribution tests PASSED!")
    else:
        print(" ❌ Some tests FAILED — check prompt attribution rules above.")
    print(rule + "\n")
    return failed
if __name__ == "__main__":
    # Propagate the outcome through the exit status so shells and CI can
    # detect failures. A falsy result from run_tests() (None or 0) is treated
    # as success; any non-zero failure count exits with status 1.
    sys.exit(1 if asyncio.run(run_tests()) else 0)