Spaces:
Sleeping
Sleeping
"""
RAG Pipeline Evaluation (v2)
------------------------------
Comprehensive tests covering:
- Query classification accuracy
- Retrieval recall (expected documents present in the results)
- Guardrail validation (on-topic / off-topic detection)
- End-to-end response quality (requires GROQ_API_KEY)

Usage:
    python evaluate.py                     # classification + retrieval + guardrail tests
    GROQ_API_KEY=xxx python evaluate.py    # all tests including end-to-end
"""
| import os | |
| import sys | |
| import json | |
| from pathlib import Path | |
| from datetime import datetime | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from src.ingest import load_knowledge_base, build_documents, get_embeddings, build_vector_store, load_vector_store | |
| from src.retriever import HybridRetriever, classify_query | |
| from src.chain import check_guardrails | |
| # --------------------------------------------------------------------------- | |
| # TEST CASES | |
| # --------------------------------------------------------------------------- | |
# Retrieval cases. Each entry has a "query", a human-readable "description"
# (printed in the report), and one of:
#   - "expected_ids":  document IDs that should appear in the retrieved set
#   - "expected_type": a document category the results should contain
RETRIEVAL_TESTS = [
    {
        "query": "What opportunities are available for investigative journalists in Africa?",
        "expected_ids": ["opp-002", "opp-017"],
        "description": "Region + topic filter — Africa investigative",
    },
    {
        "query": "Find fellowships with deadlines in the next 30 days",
        "expected_type": "fellowship",
        "description": "Deadline + type filter — fellowships",
    },
    {
        "query": "What resources does IJNet have on AI tools for journalists?",
        "expected_ids": ["art-001", "opp-007", "opp-020"],
        "description": "Topic search — AI tools",
    },
    {
        "query": "Can you summarize the latest opportunities for product/design people in newsrooms?",
        "expected_ids": ["art-003", "opp-016", "opp-015"],
        "description": "Product/design role search",
    },
    {
        "query": "Which IJNet newsletter should I subscribe to?",
        "expected_ids": ["art-002"],
        "description": "Newsletter-specific query",
    },
    {
        "query": "What grants are available for data journalism?",
        "expected_ids": ["opp-005", "opp-013"],
        "description": "Grant type + data journalism topic",
    },
    {
        "query": "Tell me about digital security for journalists",
        "expected_ids": ["art-004"],
        "description": "Article retrieval — digital security",
    },
    {
        "query": "What training programs exist for journalists in the Middle East?",
        "expected_ids": ["opp-007", "opp-012"],
        "description": "Region filter — MENA",
    },
    {
        "query": "Climate change reporting opportunities",
        "expected_ids": ["opp-004", "opp-008"],
        "description": "Topic — environment/climate",
    },
    {
        "query": "What is IJNet?",
        "expected_ids": ["ijnet-about"],
        "description": "About IJNet query",
    },
    {
        "query": "Opportunities for women journalists in Africa",
        "expected_ids": ["opp-011"],
        "description": "Women + Africa filter",
    },
    {
        "query": "How can freelance journalists find funding?",
        "expected_ids": ["art-006"],
        "description": "Freelance funding article",
    },
    {
        "query": "fact-checking training workshops",
        "expected_ids": ["opp-017"],
        "description": "Fact-checking topic",
    },
    {
        "query": "press freedom fellowships",
        "expected_ids": ["opp-019"],
        "description": "Press freedom topic",
    },
    {
        "query": "mobile journalism webinar",
        "expected_ids": ["opp-012"],
        "description": "MoJo / mobile journalism",
    },
]
# Classifier cases as (query, expected_intent, expected_filters) triples.
# An empty filter dict means only the intent is verified for that query.
CLASSIFICATION_TESTS = [
    (
        "Find fellowships with deadlines in the next 30 days",
        "deadline_search",
        {"deadline_days": 30},
    ),
    (
        "Opportunities for journalists in Africa",
        "region_search",
        {},
    ),
    (
        "Which newsletter should I subscribe to?",
        "newsletter",
        {},
    ),
    (
        "What is IJNet?",
        "about",
        {},
    ),
    (
        "AI tools for newsrooms",
        "general",
        {},
    ),
    (
        "Grants expiring within 60 days",
        "deadline_search",
        {"deadline_days": 60},
    ),
    (
        "Training programs in the Middle East",
        "region_search",
        {},
    ),
    (
        "Fellowships closing in the next 2 weeks",
        "deadline_search",
        {"deadline_days": 14},
    ),
    (
        "Data journalism awards",
        "general",
        {},
    ),
    (
        "What opportunities are there in South Asia?",
        "region_search",
        {},
    ),
]
# Guardrail cases as (query, should_be_allowed) pairs.
GUARDRAIL_TESTS = [
    # On-topic queries, greetings and short help requests: must pass through.
    ("What fellowships are available for African journalists?", True),
    ("Tell me about AI tools for newsrooms", True),
    ("Which IJNet newsletter should I subscribe to?", True),
    ("Hello", True),
    ("Thanks for the help!", True),
    ("What grants exist?", True),
    ("help", True),

    # Off-topic queries: must be blocked.
    ("Write me a poem about the moon", False),
    ("What's the weather in New York?", False),
    ("How do I cook pasta carbonara?", False),
    ("Solve this math equation: 2x + 5 = 15", False),
    ("Translate this to French: hello world", False),
    ("Tell me a joke", False),

    # Edge cases — journalism-adjacent, so they should still be allowed.
    ("How can journalists use AI?", True),
    ("media training opportunities", True),
    ("press freedom in Asia", True),
]
# End-to-end cases. Terms are compared case-insensitively as plain
# substrings of the generated answer, so stems such as "investigat" match
# "investigative" / "investigation".
E2E_TESTS = [
    {
        "query": "What opportunities are available for investigative journalists in Africa?",
        "must_contain": ["Africa", "investigat"],
        "must_not_contain": ["I don't have information"],
        "description": "Should find African investigative opportunities",
    },
    {
        "query": "Which IJNet newsletter should I subscribe to?",
        "must_contain": ["newsletter", "subscribe"],
        "must_not_contain": ["I don't have information"],
        "description": "Should describe newsletter options",
    },
    {
        "query": "Write me a poem about the ocean",
        "must_contain": ["journalism", "IJNet"],
        # "sea" is intentionally not listed: with substring matching it also
        # hits "search" and "research", which a correct redirect message can
        # easily contain. "ocean" and "poem" already cover an actual poem.
        "must_not_contain": ["ocean", "poem"],
        "description": "Should reject off-topic and redirect",
    },
    {
        "query": "What AI tools can journalists use?",
        "must_contain": ["AI", "tool"],
        "must_not_contain": ["I don't have information"],
        "description": "Should discuss AI tools from the article",
    },
    {
        "query": "Are there any grants for data journalism?",
        "must_contain": ["data journalism", "grant"],
        "must_not_contain": [],
        "description": "Should find data journalism grants",
    },
]
| # --------------------------------------------------------------------------- | |
| # TEST RUNNERS | |
| # --------------------------------------------------------------------------- | |
def run_classification_tests():
    """Run every CLASSIFICATION_TESTS case through classify_query and report.

    A case passes when the returned intent equals the expected intent and
    every expected filter key is present in the returned filters with the
    expected value. Filters returned beyond the expected ones are ignored.

    Returns:
        A ``(passed, total)`` tuple of counts.
    """
    print("\n" + "=" * 60)
    print("1. QUERY CLASSIFICATION TESTS")
    print("=" * 60)

    passed = 0
    for query, expected_intent, expected_filters in CLASSIFICATION_TESTS:
        result = classify_query(query)
        intent_match = result["intent"] == expected_intent

        # Collect the actual value of every expected filter that differs, so
        # a filter-only failure is reported as such.
        wrong_filters = {}
        for key, val in expected_filters.items():
            actual = result["filters"].get(key)
            if actual != val:
                wrong_filters[key] = actual

        ok = intent_match and not wrong_filters
        status = "✅" if ok else "❌"
        print(f" {status} \"{query[:55]}...\"" if len(query) > 55 else f" {status} \"{query}\"")

        if ok:
            passed += 1
            continue
        if not intent_match:
            print(f" Expected: {expected_intent}, Got: {result['intent']}")
        if wrong_filters:
            wanted = {key: expected_filters[key] for key in wrong_filters}
            print(f" Expected filters: {wanted}, Got: {wrong_filters}")

    total = len(CLASSIFICATION_TESTS)
    # Guard the percentage so an empty suite reports 0% instead of crashing.
    rate = passed / total if total else 0.0
    print(f"\n Result: {passed}/{total} passed ({rate:.0%})")
    return passed, total
def run_guardrail_tests():
    """Run every GUARDRAIL_TESTS case through check_guardrails and report.

    A case passes when the allow/block decision returned by
    ``check_guardrails`` equals the expected decision. The message that
    accompanies the decision is not inspected.

    Returns:
        A ``(passed, total)`` tuple of counts.
    """
    print("\n" + "=" * 60)
    print("2. GUARDRAIL TESTS")
    print("=" * 60)

    passed = 0
    for query, should_allow in GUARDRAIL_TESTS:
        is_allowed, _ = check_guardrails(query)
        correct = is_allowed == should_allow
        status = "✅" if correct else "❌"
        expected = "allow" if should_allow else "block"
        actual = "allowed" if is_allowed else "blocked"
        print(f" {status} [{expected}] \"{query[:50]}\" → {actual}")
        if correct:
            passed += 1

    total = len(GUARDRAIL_TESTS)
    # Guard the percentage so an empty suite reports 0% instead of crashing.
    rate = passed / total if total else 0.0
    print(f"\n Result: {passed}/{total} passed ({rate:.0%})")
    return passed, total
def run_retrieval_tests(retriever: HybridRetriever):
    """Run every RETRIEVAL_TESTS case against ``retriever`` and report.

    A case passes when both of the following hold (each check is skipped if
    the case does not specify it):

    * ``expected_ids``: at least half of the expected document IDs appear in
      the ``doc_id`` metadata of the retrieved documents.
    * ``expected_type``: at least one retrieved document has that type.

    Args:
        retriever: Object whose ``retrieve(query)`` returns documents that
            expose a ``metadata`` mapping.

    Returns:
        A ``(passed, total)`` tuple of counts.
    """
    print("\n" + "=" * 60)
    print("3. RETRIEVAL TESTS")
    print("=" * 60)

    passed = 0
    total = len(RETRIEVAL_TESTS)
    for test in RETRIEVAL_TESTS:
        expected_ids = test.get("expected_ids", [])
        expected_type = test.get("expected_type", None)

        results = retriever.retrieve(test["query"])
        retrieved_ids = [doc.metadata.get("doc_id", "") for doc in results]

        ids_ok = True
        if expected_ids:
            found = [eid for eid in expected_ids if eid in retrieved_ids]
            ids_ok = len(found) / len(expected_ids) >= 0.5

        type_ok = True
        retrieved_types = []
        if expected_type:
            # NOTE(review): assumes the document category is stored under
            # metadata["type"] — confirm against src.ingest.build_documents.
            retrieved_types = [
                str(doc.metadata.get("type", "")).lower() for doc in results
            ]
            type_ok = expected_type.lower() in retrieved_types

        test_passed = ids_ok and type_ok
        status = "✅" if test_passed else "❌"
        print(f" {status} {test['description']}")

        if test_passed:
            passed += 1
            continue
        if not ids_ok:
            print(f" Expected: {expected_ids}, Got: {retrieved_ids}")
        if not type_ok:
            print(f" Expected type: {expected_type}, Got types: {retrieved_types}")

    # Guard the percentage so an empty suite reports 0% instead of crashing.
    rate = passed / total if total else 0.0
    print(f"\n Result: {passed}/{total} passed ({rate:.0%})")
    return passed, total
def run_e2e_tests(retriever: HybridRetriever):
    """Run every E2E_TESTS case through the full RAG chain and report.

    The suite is skipped, returning ``(0, 0)``, when the ``GROQ_API_KEY``
    environment variable is not set. Otherwise each query is sent through
    ``IJNetRAGChain`` and the lower-cased answer is checked: every
    ``must_contain`` term has to occur as a substring and no
    ``must_not_contain`` term may occur.

    Args:
        retriever: Retriever handed to ``IJNetRAGChain``.

    Returns:
        A ``(passed, total)`` tuple of counts; ``(0, 0)`` when skipped.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        print("\n" + "=" * 60)
        print("4. END-TO-END TESTS (SKIPPED — set GROQ_API_KEY to run)")
        print("=" * 60)
        return 0, 0

    print("\n" + "=" * 60)
    print("4. END-TO-END TESTS")
    print("=" * 60)

    # Imported here so the other suites can run without loading the chain.
    from src.chain import IJNetRAGChain

    chain = IJNetRAGChain(retriever=retriever, groq_api_key=api_key)

    passed = 0
    for test in E2E_TESTS:
        query = test["query"]
        try:
            result = chain.query(query)
            answer = result["answer"].lower()

            missing = [t for t in test["must_contain"] if t.lower() not in answer]
            found_bad = [t for t in test["must_not_contain"] if t.lower() in answer]
            test_passed = not missing and not found_bad

            status = "✅" if test_passed else "❌"
            print(f" {status} {test['description']}")
            if not test_passed:
                if missing:
                    print(f" Missing terms: {missing}")
                if found_bad:
                    print(f" Unwanted terms found: {found_bad}")
                print(f" Response preview: {answer[:150]}...")
            if test_passed:
                passed += 1
        except Exception as e:
            # Deliberately broad: one failing case is recorded as a failure
            # and the remaining cases still run.
            print(f" ❌ {test['description']} — Error: {e}")

    total = len(E2E_TESTS)
    # Guard the percentage so an empty suite reports 0% instead of crashing.
    rate = passed / total if total else 0.0
    print(f"\n Result: {passed}/{total} passed ({rate:.0%})")
    return passed, total
| # --------------------------------------------------------------------------- | |
| # MAIN | |
| # --------------------------------------------------------------------------- | |
def main():
    """Build the retrieval pipeline, run every test suite and print a summary.

    The knowledge base and FAISS index are looked up in the ``data``
    directory next to this script, so the evaluation behaves the same
    regardless of the current working directory. An existing index is
    loaded; otherwise a new one is built at that location.
    """
    print("=" * 60)
    print("IJNet RAG Pipeline — Evaluation Suite (v2)")
    print("=" * 60)

    # Initialize
    print("\nInitializing pipeline...")
    data_dir = Path(__file__).parent / "data"
    kb = load_knowledge_base(str(data_dir / "knowledge_base.json"))
    documents = build_documents(kb)
    embeddings = get_embeddings()

    index_path = str(data_dir / "faiss_index")
    if Path(index_path).exists():
        vector_store = load_vector_store(index_path, embeddings)
    else:
        vector_store = build_vector_store(documents, embeddings, index_path)

    retriever = HybridRetriever(
        vector_store=vector_store,
        documents=documents,
        semantic_k=8,
        bm25_k=8,
        final_k=5,
    )
    print(f"Pipeline ready. {len(documents)} documents indexed.")

    # Run all test suites; each runner returns (passed, total).
    results = []
    results.append(("Classification", *run_classification_tests()))
    results.append(("Guardrails", *run_guardrail_tests()))
    results.append(("Retrieval", *run_retrieval_tests(retriever)))
    results.append(("End-to-End", *run_e2e_tests(retriever)))

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    total_passed = 0
    total_tests = 0
    for name, passed, total in results:
        # Suites that were skipped report total == 0 and are left out.
        if total > 0:
            print(f" {name:20s}: {passed}/{total} ({passed/total:.0%})")
            total_passed += passed
            total_tests += total
    if total_tests > 0:
        print(f" {'OVERALL':20s}: {total_passed}/{total_tests} ({total_passed/total_tests:.0%})")
    print("=" * 60)


if __name__ == "__main__":
    main()