IJNet-assistant / evaluate.py
Mohammad Haris
Deploy IJNet assistant
b87aca1
Raw
History Blame Contribute Delete
13.4 kB
"""
RAG Pipeline Evaluation (v2)
------------------------------
Comprehensive tests covering:
- Query classification accuracy
- Retrieval precision and recall
- Guardrail validation (on-topic / off-topic detection)
- End-to-end response quality (requires GROQ_API_KEY)
Usage:
    python evaluate.py                    # classification + retrieval + guardrail tests
    GROQ_API_KEY=xxx python evaluate.py   # all tests including end-to-end
"""
import os
import sys
import json
from pathlib import Path
from datetime import datetime
sys.path.insert(0, str(Path(__file__).parent))
from src.ingest import load_knowledge_base, build_documents, get_embeddings, build_vector_store, load_vector_store
from src.retriever import HybridRetriever, classify_query
from src.chain import check_guardrails
# ---------------------------------------------------------------------------
# TEST CASES
# ---------------------------------------------------------------------------
# Retrieval cases consumed by run_retrieval_tests().
# Every entry is a dict with:
#   "query"         - text handed to retriever.retrieve()
#   "description"   - label printed next to the pass/fail marker
# and one optional expectation:
#   "expected_ids"  - ids looked for among the "doc_id" metadata values of the
#                     retrieved documents
#   "expected_type" - read by run_retrieval_tests() via test.get();
#                     presumably the document type the results should have
#                     (TODO confirm against the knowledge-base schema)
RETRIEVAL_TESTS = [
{
"query": "What opportunities are available for investigative journalists in Africa?",
"expected_ids": ["opp-002", "opp-017"],
"description": "Region + topic filter β€” Africa investigative",
},
{
"query": "Find fellowships with deadlines in the next 30 days",
"expected_type": "fellowship",
"description": "Deadline + type filter β€” fellowships",
},
{
"query": "What resources does IJNet have on AI tools for journalists?",
"expected_ids": ["art-001", "opp-007", "opp-020"],
"description": "Topic search β€” AI tools",
},
{
"query": "Can you summarize the latest opportunities for product/design people in newsrooms?",
"expected_ids": ["art-003", "opp-016", "opp-015"],
"description": "Product/design role search",
},
{
"query": "Which IJNet newsletter should I subscribe to?",
"expected_ids": ["art-002"],
"description": "Newsletter-specific query",
},
{
"query": "What grants are available for data journalism?",
"expected_ids": ["opp-005", "opp-013"],
"description": "Grant type + data journalism topic",
},
{
"query": "Tell me about digital security for journalists",
"expected_ids": ["art-004"],
"description": "Article retrieval β€” digital security",
},
{
"query": "What training programs exist for journalists in the Middle East?",
"expected_ids": ["opp-007", "opp-012"],
"description": "Region filter β€” MENA",
},
{
"query": "Climate change reporting opportunities",
"expected_ids": ["opp-004", "opp-008"],
"description": "Topic β€” environment/climate",
},
{
"query": "What is IJNet?",
"expected_ids": ["ijnet-about"],
"description": "About IJNet query",
},
{
"query": "Opportunities for women journalists in Africa",
"expected_ids": ["opp-011"],
"description": "Women + Africa filter",
},
{
"query": "How can freelance journalists find funding?",
"expected_ids": ["art-006"],
"description": "Freelance funding article",
},
{
"query": "fact-checking training workshops",
"expected_ids": ["opp-017"],
"description": "Fact-checking topic",
},
{
"query": "press freedom fellowships",
"expected_ids": ["opp-019"],
"description": "Press freedom topic",
},
{
"query": "mobile journalism webinar",
"expected_ids": ["opp-012"],
"description": "MoJo / mobile journalism",
},
]
# Classification cases consumed by run_classification_tests().
# Each tuple is (query, expected intent, expected filters).
# The filters dict is a subset check: only the keys listed here are compared
# with classify_query(query)["filters"]; an empty dict means the filters are
# not inspected for that case.
CLASSIFICATION_TESTS = [
("Find fellowships with deadlines in the next 30 days", "deadline_search", {"deadline_days": 30}),
("Opportunities for journalists in Africa", "region_search", {}),
("Which newsletter should I subscribe to?", "newsletter", {}),
("What is IJNet?", "about", {}),
("AI tools for newsrooms", "general", {}),
("Grants expiring within 60 days", "deadline_search", {"deadline_days": 60}),
("Training programs in the Middle East", "region_search", {}),
("Fellowships closing in the next 2 weeks", "deadline_search", {"deadline_days": 14}),
("Data journalism awards", "general", {}),
("What opportunities are there in South Asia?", "region_search", {}),
]
# Guardrail cases consumed by run_guardrail_tests().
# Each tuple pairs a query with the verdict expected from check_guardrails():
# True when the query should be let through, False when it should be blocked.
GUARDRAIL_TESTS = [
# (query, should_be_allowed)
# On-topic queries, greetings and short help requests
("What fellowships are available for African journalists?", True),
("Tell me about AI tools for newsrooms", True),
("Which IJNet newsletter should I subscribe to?", True),
("Hello", True),
("Thanks for the help!", True),
("What grants exist?", True),
("help", True),
# Off-topic queries
("Write me a poem about the moon", False),
("What's the weather in New York?", False),
("How do I cook pasta carbonara?", False),
("Solve this math equation: 2x + 5 = 15", False),
("Translate this to French: hello world", False),
("Tell me a joke", False),
# Edge cases - should still be allowed (journalism-adjacent)
("How can journalists use AI?", True),
("media training opportunities", True),
("press freedom in Asia", True),
]
# End-to-end cases consumed by run_e2e_tests().
# Every entry is a dict with:
#   "query"            - question sent through the full chain
#   "must_contain"     - terms that all have to occur in the answer
#   "must_not_contain" - terms that must not occur in the answer
#   "description"      - label printed next to the pass/fail marker
# run_e2e_tests() lower-cases both the answer and the terms and uses plain
# substring matching, which is why stems such as "investigat" work.
E2E_TESTS = [
{
"query": "What opportunities are available for investigative journalists in Africa?",
"must_contain": ["Africa", "investigat"],
"must_not_contain": ["I don't have information"],
"description": "Should find African investigative opportunities",
},
{
"query": "Which IJNet newsletter should I subscribe to?",
"must_contain": ["newsletter", "subscribe"],
"must_not_contain": ["I don't have information"],
"description": "Should describe newsletter options",
},
{
"query": "Write me a poem about the ocean",
"must_contain": ["journalism", "IJNet"],
# NOTE(review): matching is by substring, so "sea" is also found inside
# words such as "search" or "research" - confirm this is intended.
"must_not_contain": ["ocean", "poem", "sea"],
"description": "Should reject off-topic and redirect",
},
{
"query": "What AI tools can journalists use?",
"must_contain": ["AI", "tool"],
"must_not_contain": ["I don't have information"],
"description": "Should discuss AI tools from the article",
},
{
"query": "Are there any grants for data journalism?",
"must_contain": ["data journalism", "grant"],
"must_not_contain": [],
"description": "Should find data journalism grants",
},
]
# ---------------------------------------------------------------------------
# TEST RUNNERS
# ---------------------------------------------------------------------------
def run_classification_tests():
    """Check classify_query() against every case in CLASSIFICATION_TESTS.

    A case passes when the returned ``"intent"`` equals the expected intent
    and every key listed in the expected filters has the same value in the
    returned ``"filters"`` mapping. Keys the case does not list are ignored.

    One line is printed per case; a failing case additionally reports the
    intent and, when the filters are the culprit, the offending filter values.

    Returns:
        ``(passed, total)`` tuple.
    """
    print("\n" + "=" * 60)
    print("1. QUERY CLASSIFICATION TESTS")
    print("=" * 60)
    passed = 0
    for query, expected_intent, expected_filters in CLASSIFICATION_TESTS:
        result = classify_query(query)
        intent_match = result["intent"] == expected_intent
        # Collect the actual value of each expected filter that differs.
        # With no expected filters, result["filters"] is never touched.
        wrong_filters = {
            key: result["filters"].get(key)
            for key, val in expected_filters.items()
            if result["filters"].get(key) != val
        }
        case_ok = intent_match and not wrong_filters
        status = "βœ…" if case_ok else "❌"
        # Long queries are cut to 55 characters and marked with an ellipsis.
        label = f"{query[:55]}..." if len(query) > 55 else query
        print(f" {status} \"{label}\"")
        if case_ok:
            passed += 1
            continue
        print(f" Expected: {expected_intent}, Got: {result['intent']}")
        if wrong_filters:
            # Without this line a filter-only failure would read
            # "Expected: X, Got: X" and hide the real cause.
            print(f" Expected filters: {expected_filters}, Got: {wrong_filters}")
    total = len(CLASSIFICATION_TESTS)
    # Guard the percentage so an empty case list does not divide by zero.
    rate = passed / total if total else 0.0
    print(f"\n Result: {passed}/{total} passed ({rate:.0%})")
    return passed, total
def run_guardrail_tests():
    """Compare check_guardrails() verdicts with the GUARDRAIL_TESTS table.

    Each case pairs a query with the verdict it should receive. The first
    element returned by check_guardrails() is taken as the actual verdict;
    the second element is not used here.

    Returns:
        ``(passed, total)`` tuple.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("2. GUARDRAIL TESTS")
    print(rule)

    passed = 0
    for question, want_allowed in GUARDRAIL_TESTS:
        got_allowed, _ = check_guardrails(question)
        verdict_ok = got_allowed == want_allowed
        if verdict_ok:
            marker = "βœ…"
            passed += 1
        else:
            marker = "❌"
        wanted_label = "allow" if want_allowed else "block"
        got_label = "allowed" if got_allowed else "blocked"
        # Only the first 50 characters of the query are echoed.
        print(f" {marker} [{wanted_label}] \"{question[:50]}\" β†’ {got_label}")

    total = len(GUARDRAIL_TESTS)
    print(f"\n Result: {passed}/{total} passed ({passed/total:.0%})")
    return passed, total
def run_retrieval_tests(retriever: HybridRetriever, type_key: str = "type"):
    """Run every RETRIEVAL_TESTS case through the retriever and report results.

    A case passes when each expectation it declares holds:

    * ``expected_ids`` - at least half of the listed ids appear among the
      ``doc_id`` metadata values of the retrieved documents.
    * ``expected_type`` - at least one retrieved document carries that value
      (compared case-insensitively) under ``type_key`` in its metadata.
      Previously this expectation was read but never checked, so type-only
      cases could not fail.

    Args:
        retriever: Object exposing ``retrieve(query)``; every returned
            document must provide a ``metadata`` mapping.
        type_key: Metadata key that holds a document's type.
            NOTE(review): the default ``"type"`` is an assumption - the
            metadata schema is defined outside this file; confirm it against
            the ingestion code and adjust if the key is named differently.

    Returns:
        ``(passed, total)`` tuple.
    """
    print("\n" + "=" * 60)
    print("3. RETRIEVAL TESTS")
    print("=" * 60)
    passed = 0
    total = len(RETRIEVAL_TESTS)
    for test in RETRIEVAL_TESTS:
        expected_ids = test.get("expected_ids", [])
        expected_type = test.get("expected_type")
        results = retriever.retrieve(test["query"])
        retrieved_ids = [doc.metadata.get("doc_id", "") for doc in results]

        # Recall over the expected ids; a case without ids skips this check.
        ids_ok = True
        if expected_ids:
            hits = sum(1 for eid in expected_ids if eid in retrieved_ids)
            ids_ok = hits / len(expected_ids) >= 0.5

        # Type expectation: one matching document among the results suffices.
        type_ok = True
        retrieved_types = []
        if expected_type:
            retrieved_types = [str(doc.metadata.get(type_key, "")) for doc in results]
            wanted = str(expected_type).lower()
            type_ok = any(found.lower() == wanted for found in retrieved_types)

        test_passed = ids_ok and type_ok
        status = "βœ…" if test_passed else "❌"
        print(f" {status} {test['description']}")
        if not ids_ok:
            print(f" Expected: {expected_ids}, Got: {retrieved_ids}")
        if not type_ok:
            print(f" Expected type: {expected_type}, Got: {retrieved_types}")
        if test_passed:
            passed += 1
    # Guard the percentage so an empty case list does not divide by zero.
    rate = passed / total if total else 0.0
    print(f"\n Result: {passed}/{total} passed ({rate:.0%})")
    return passed, total
def run_e2e_tests(retriever: HybridRetriever):
    """Send each E2E_TESTS query through the full chain and check the answer.

    The suite is skipped, returning ``(0, 0)``, when the ``GROQ_API_KEY``
    environment variable is unset or empty. Otherwise an ``IJNetRAGChain`` is
    built around ``retriever`` and every answer is lower-cased and checked by
    substring: all ``must_contain`` terms have to be present and no
    ``must_not_contain`` term may be. An exception raised while handling a
    case is printed and the case counts as failed.

    Returns:
        ``(passed, total)`` tuple.
    """
    rule = "=" * 60
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        print("\n" + rule)
        print("4. END-TO-END TESTS (SKIPPED β€” set GROQ_API_KEY to run)")
        print(rule)
        return 0, 0

    print("\n" + rule)
    print("4. END-TO-END TESTS")
    print(rule)

    # Imported here so the chain is only loaded when the suite actually runs.
    from src.chain import IJNetRAGChain

    chain = IJNetRAGChain(retriever=retriever, groq_api_key=api_key)
    passed = 0
    for case in E2E_TESTS:
        question = case["query"]
        try:
            reply = chain.query(question)["answer"].lower()
            # Required terms that did not show up in the answer.
            absent = [t for t in case["must_contain"] if t.lower() not in reply]
            # Forbidden terms that did show up in the answer.
            unwanted = [t for t in case["must_not_contain"] if t.lower() in reply]
            case_ok = not absent and not unwanted
            marker = "βœ…" if case_ok else "❌"
            print(f" {marker} {case['description']}")
            if case_ok:
                passed += 1
                continue
            if absent:
                print(f" Missing terms: {absent}")
            if unwanted:
                print(f" Unwanted terms found: {unwanted}")
            print(f" Response preview: {reply[:150]}...")
        except Exception as e:
            print(f" ❌ {case['description']} β€” Error: {e}")

    total = len(E2E_TESTS)
    print(f"\n Result: {passed}/{total} passed ({passed/total:.0%})")
    return passed, total
# ---------------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------------
def main():
    """Set up the retrieval pipeline, run the four suites and print a summary.

    The knowledge base is read from ``data/knowledge_base.json``. The vector
    store is loaded from ``data/faiss_index`` when that path exists and built
    with build_vector_store() otherwise. The existence check uses a relative
    path, so it depends on the current working directory.

    Suites run in a fixed order: classification, guardrails, retrieval,
    end-to-end. Suites reporting zero cases (such as a skipped end-to-end run)
    are left out of the summary and the overall figure.
    """
    rule = "=" * 60
    print(rule)
    print("IJNet RAG Pipeline β€” Evaluation Suite (v2)")
    print(rule)

    # Initialize
    print("\nInitializing pipeline...")
    documents = build_documents(load_knowledge_base("data/knowledge_base.json"))
    embeddings = get_embeddings()
    index_path = "data/faiss_index"
    vector_store = (
        load_vector_store(index_path, embeddings)
        if Path(index_path).exists()
        else build_vector_store(documents, embeddings, index_path)
    )
    retriever = HybridRetriever(
        vector_store=vector_store,
        documents=documents,
        semantic_k=8,
        bm25_k=8,
        final_k=5,
    )
    print(f"Pipeline ready. {len(documents)} documents indexed.")

    # Run all test suites; the list display evaluates them top to bottom.
    results = [
        ("Classification", *run_classification_tests()),
        ("Guardrails", *run_guardrail_tests()),
        ("Retrieval", *run_retrieval_tests(retriever)),
        ("End-to-End", *run_e2e_tests(retriever)),
    ]

    # Summary
    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    total_passed = 0
    total_tests = 0
    for name, passed, total in results:
        if total <= 0:
            # Nothing ran for this suite, so there is no ratio to show.
            continue
        print(f" {name:20s}: {passed}/{total} ({passed/total:.0%})")
        total_passed += passed
        total_tests += total
    if total_tests > 0:
        print(f" {'OVERALL':20s}: {total_passed}/{total_tests} ({total_passed/total_tests:.0%})")
    print(rule)


if __name__ == "__main__":
    main()