docmind / tests /test_grounding.py
AI Engineer
Initial commit for DocMind
6cca5b1
Raw
History Blame Contribute Delete
2.85 kB
"""
Tests for DocMind NLI Grounding Gate.
NOTE: These tests require the NLI model to be downloaded.
They will be slower on first run (~10s for model download).
Run with: python tests/test_grounding.py
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
def test_entailment_high_score():
    """A sentence clearly stated in the premise should score high.

    Loads the NLI model via ``pipeline.grounding.load_nli_model`` and checks
    that ``score_sentence`` gives a near-paraphrase of the premise an
    entailment score of at least ``threshold``.
    """
    from pipeline.grounding import load_nli_model, score_sentence

    # load_nli_model is presumably wrapped by a Streamlit cache decorator;
    # call the undecorated function when ``__wrapped__`` is available so the
    # test can run outside a Streamlit session, else call it directly.
    loader = getattr(load_nli_model, "__wrapped__", load_nli_model)
    tokenizer, model = loader()

    premise = "Full-time employees receive 30 days of paid vacation per year."
    hypothesis = "Full-time employees receive 30 days of paid vacation."
    # One bound drives both the printed expectation and the assertion so the
    # two cannot drift apart.
    threshold = 0.5
    score = score_sentence(tokenizer, model, premise, hypothesis)
    print(f" Entailment score: {score:.4f} (expected >= {threshold})")
    assert score >= threshold, f"Expected high entailment, got {score}"
def test_contradiction_low_score():
    """A contradictory sentence should score low on entailment.

    Uses a hypothesis that directly conflicts with the premise and checks
    that ``score_sentence`` returns an entailment score of at most
    ``threshold``.
    """
    from pipeline.grounding import load_nli_model, score_sentence

    # load_nli_model is presumably wrapped by a Streamlit cache decorator;
    # call the undecorated function when ``__wrapped__`` is available so the
    # test can run outside a Streamlit session, else call it directly.
    loader = getattr(load_nli_model, "__wrapped__", load_nli_model)
    tokenizer, model = loader()

    premise = "The product is available only in the United States."
    hypothesis = "The product is available worldwide in all countries."
    # One bound drives both the printed expectation and the assertion so the
    # two cannot drift apart.
    threshold = 0.5
    score = score_sentence(tokenizer, model, premise, hypothesis)
    print(f" Contradiction score: {score:.4f} (expected <= {threshold})")
    assert score <= threshold, f"Expected low entailment for contradiction, got {score}"
def test_neutral_moderate_score():
    """An unrelated sentence should get low-to-moderate entailment.

    Uses a hypothesis that shares a topic word with the premise but is not
    implied by it, and checks that ``score_sentence`` returns an entailment
    score of at most ``threshold``.
    """
    from pipeline.grounding import load_nli_model, score_sentence

    # load_nli_model is presumably wrapped by a Streamlit cache decorator;
    # call the undecorated function when ``__wrapped__`` is available so the
    # test can run outside a Streamlit session, else call it directly.
    loader = getattr(load_nli_model, "__wrapped__", load_nli_model)
    tokenizer, model = loader()

    premise = "The company was founded in 2005 in San Francisco."
    hypothesis = "The weather in San Francisco is often foggy."
    # One bound drives both the printed expectation and the assertion so the
    # two cannot drift apart.
    threshold = 0.5
    score = score_sentence(tokenizer, model, premise, hypothesis)
    print(f" Neutral score: {score:.4f} (expected <= {threshold})")
    assert score <= threshold, f"Expected low entailment for neutral, got {score}"
if __name__ == "__main__":
    # Minimal standalone runner so the tests can be executed without pytest:
    #     python tests/test_grounding.py
    tests = [
        test_entailment_high_score,
        test_contradiction_low_score,
        test_neutral_moderate_score,
    ]
    passed = failed = 0
    for test_fn in tests:
        name = test_fn.__name__
        try:
            test_fn()
        except AssertionError as e:
            print(f" [FAIL] {name}: {e}")
            failed += 1
        except Exception as e:
            # Setup errors (missing dependency, model download failure, ...)
            # are reported as failures instead of aborting the remaining tests.
            print(f" [FAIL] {name}: {type(e).__name__}: {e}")
            failed += 1
        else:
            print(f" [PASS] {name}")
            passed += 1
    print(f"\n{'='*40}")
    print(f"Results: {passed} passed, {failed} failed")
    # Propagate the outcome as the process exit status so CI and shell
    # scripts can detect failures (previously this always exited 0).
    sys.exit(1 if failed else 0)