| """ |
| Tests for DocMind NLI Grounding Gate. |
| |
| NOTE: These tests require the NLI model to be downloaded. |
| They will be slower on first run (~10s for model download). |
| Run with: python tests/test_grounding.py |
| """ |
|
|
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
|
|
|
|
def test_entailment_high_score():
    """A sentence clearly stated in the premise should score high.

    Loads the tokenizer/model pair, scores a hypothesis that restates the
    premise, prints the score, and asserts it reaches the threshold.

    Raises:
        AssertionError: If the entailment score is below the threshold.
    """
    from pipeline.grounding import load_nli_model, score_sentence

    # A single threshold drives both the printed expectation and the
    # assertion, so the report can never disagree with the pass/fail result.
    min_score = 0.5

    # Prefer the ``__wrapped__`` callable when present (presumably the
    # undecorated loader behind a caching decorator -- confirm in
    # pipeline.grounding); fall back to the function itself so the test
    # still runs if the loader is not decorated.
    loader = getattr(load_nli_model, "__wrapped__", load_nli_model)
    tokenizer, model = loader()

    premise = "Full-time employees receive 30 days of paid vacation per year."
    hypothesis = "Full-time employees receive 30 days of paid vacation."

    score = score_sentence(tokenizer, model, premise, hypothesis)
    print(f" Entailment score: {score:.4f} (expected >= {min_score})")
    assert score >= min_score, f"Expected high entailment, got {score}"
|
|
|
|
def test_contradiction_low_score():
    """A contradictory sentence should score low on entailment.

    Loads the tokenizer/model pair, scores a hypothesis that contradicts the
    premise, prints the score, and asserts it stays at or below the
    threshold.

    Raises:
        AssertionError: If the entailment score exceeds the threshold.
    """
    from pipeline.grounding import load_nli_model, score_sentence

    # A single threshold drives both the printed expectation and the
    # assertion, so the report can never disagree with the pass/fail result.
    max_score = 0.5

    # Prefer the ``__wrapped__`` callable when present (presumably the
    # undecorated loader behind a caching decorator -- confirm in
    # pipeline.grounding); fall back to the function itself so the test
    # still runs if the loader is not decorated.
    loader = getattr(load_nli_model, "__wrapped__", load_nli_model)
    tokenizer, model = loader()

    premise = "The product is available only in the United States."
    hypothesis = "The product is available worldwide in all countries."

    score = score_sentence(tokenizer, model, premise, hypothesis)
    print(f" Contradiction score: {score:.4f} (expected <= {max_score})")
    assert score <= max_score, (
        f"Expected low entailment for contradiction, got {score}"
    )
|
|
|
|
def test_neutral_moderate_score():
    """An unrelated sentence should get low-to-moderate entailment.

    Loads the tokenizer/model pair, scores a hypothesis that is neither
    stated nor contradicted by the premise, prints the score, and asserts it
    stays at or below the threshold.

    Raises:
        AssertionError: If the entailment score exceeds the threshold.
    """
    from pipeline.grounding import load_nli_model, score_sentence

    # A single threshold drives both the printed expectation and the
    # assertion, so the report can never disagree with the pass/fail result.
    max_score = 0.5

    # Prefer the ``__wrapped__`` callable when present (presumably the
    # undecorated loader behind a caching decorator -- confirm in
    # pipeline.grounding); fall back to the function itself so the test
    # still runs if the loader is not decorated.
    loader = getattr(load_nli_model, "__wrapped__", load_nli_model)
    tokenizer, model = loader()

    premise = "The company was founded in 2005 in San Francisco."
    hypothesis = "The weather in San Francisco is often foggy."

    score = score_sentence(tokenizer, model, premise, hypothesis)
    print(f" Neutral score: {score:.4f} (expected <= {max_score})")
    assert score <= max_score, (
        f"Expected low entailment for neutral, got {score}"
    )
|
|
|
|
if __name__ == "__main__":
    # Standalone runner: execute every test, print a per-test verdict and a
    # summary line, and exit non-zero if anything failed.
    tests = [
        test_entailment_high_score,
        test_contradiction_low_score,
        test_neutral_moderate_score,
    ]

    # NOTE(review): `st` is not referenced anywhere in this block; the import
    # looks like a best-effort probe or a leftover -- confirm before removing.
    try:
        import streamlit as st
    except ImportError:
        pass

    passed = failed = 0
    for test_fn in tests:
        name = test_fn.__name__
        try:
            test_fn()
        except AssertionError as e:
            # The test's expectation was not met.
            print(f" [FAIL] {name}: {e}")
            failed += 1
        except Exception as e:
            # Anything else (import error, model load failure, ...) is also
            # counted as a failure; the exception type is shown for triage.
            print(f" [FAIL] {name}: {type(e).__name__}: {e}")
            failed += 1
        else:
            print(f" [PASS] {name}")
            passed += 1

    print(f"\n{'='*40}")
    print(f"Results: {passed} passed, {failed} failed")

    # Propagate failure through the exit status so shells and CI jobs can
    # detect it; previously the script always exited with status 0.
    if failed:
        sys.exit(1)
|
|