# ask-the-web-agent — tests/test_feedback.py
"""Unit tests for feedback module."""
import pytest
from src.feedback.evaluator import QualityEvaluator, QualityScore
from src.feedback.gaps import GapIdentifier, InformationGap
from src.feedback.refinement import RefinementStrategy, RefinementAction
class TestQualityEvaluator:
    """Behavioural tests for the QualityEvaluator scoring component."""

    def test_evaluate_good_response(self):
        """A substantive, well-sourced answer should clear the mid-range scores."""
        evaluator = QualityEvaluator(min_quality_threshold=0.6)
        question = "What is machine learning?"
        response = """
        Machine learning is a subset of artificial intelligence that enables
        computers to learn from data without being explicitly programmed.
        It uses algorithms to identify patterns in data, because these patterns
        allow the system to make predictions or decisions.
        """
        # Two reputable-looking sources so the sourcing dimension is non-zero.
        references = [
            {"title": "ML Guide", "url": "https://example.edu/ml"},
            {"title": "AI Overview", "url": "https://ai.gov/overview"},
        ]
        result = evaluator.evaluate(question, response, references)
        assert result.overall >= 0.5
        assert result.relevance > 0
        assert result.completeness > 0

    def test_evaluate_poor_response(self):
        """A one-word non-answer must score below the default threshold."""
        evaluator = QualityEvaluator()
        result = evaluator.evaluate("What is the capital of France?", "Hello")
        assert result.overall < 0.6
        assert result.completeness < 0.5

    def test_is_acceptable(self):
        """is_acceptable() should split scores around the configured threshold."""
        evaluator = QualityEvaluator(min_quality_threshold=0.6)
        above_threshold = QualityScore(
            relevance=0.8, completeness=0.7, accuracy=0.6,
            clarity=0.8, sourcing=0.5, overall=0.68,
            feedback=["Good"],
        )
        below_threshold = QualityScore(
            relevance=0.3, completeness=0.2, accuracy=0.4,
            clarity=0.5, sourcing=0.2, overall=0.32,
            feedback=["Poor"],
        )
        assert evaluator.is_acceptable(above_threshold) is True
        assert evaluator.is_acceptable(below_threshold) is False
class TestGapIdentifier:
    """Behavioural tests for the GapIdentifier gap-detection component."""

    def test_identify_missing_why(self):
        """An answer that never explains causation should yield a 'why' gap."""
        finder = GapIdentifier()
        gaps = finder.identify_gaps(
            "Why is the sky blue?",
            "The sky is blue. It appears blue during the day.",
        )
        # At least one gap description should mention the missing 'why'.
        assert any("why" in gap.description.lower() for gap in gaps)

    def test_identify_no_sources(self):
        """Passing sources=None should surface an 'unverified' gap."""
        finder = GapIdentifier()
        gaps = finder.identify_gaps(
            "What is AI?",
            "AI is artificial intelligence.",
            sources=None,
        )
        assert any(gap.gap_type == "unverified" for gap in gaps)

    def test_identify_uncertainty(self):
        """Hedging language in the answer should surface an 'unclear' gap."""
        finder = GapIdentifier()
        gaps = finder.identify_gaps(
            "When was Python created?",
            "I'm not sure, but Python might be from the 1990s.",
        )
        assert any(gap.gap_type == "unclear" for gap in gaps)

    def test_prioritize_gaps(self):
        """prioritize_gaps() should order gaps from high to low severity."""
        finder = GapIdentifier()
        unordered = [
            InformationGap("Low issue", "unclear", "low"),
            InformationGap("High issue", "missing_fact", "high"),
            InformationGap("Medium issue", "unverified", "medium"),
        ]
        ordered = finder.prioritize_gaps(unordered)
        assert ordered[0].severity == "high"
        assert ordered[-1].severity == "low"
class TestRefinementStrategy:
    """Behavioural tests for the RefinementStrategy iteration controller."""

    def test_analyze_low_quality(self):
        """A low quality score with a searchable gap should yield a search action."""
        planner = RefinementStrategy(max_iterations=3)
        open_gaps = [
            InformationGap(
                "Missing explanation",
                "missing_fact",
                "high",
                suggested_search="query more info",
            ),
        ]
        planned = planner.analyze(
            query="Test query",
            answer="Test answer",
            gaps=open_gaps,
            quality_score=0.4,
        )
        assert planned
        assert any(step.action_type == "search" for step in planned)

    def test_analyze_high_quality(self):
        """A high quality score with no gaps should require no actions."""
        planner = RefinementStrategy()
        planned = planner.analyze(
            query="Test query",
            answer="Comprehensive answer",
            gaps=[],
            quality_score=0.9,
        )
        # Nothing to refine when quality is already high.
        assert len(planned) == 0

    def test_should_continue(self):
        """should_continue() honours both the quality bar and the iteration cap."""
        planner = RefinementStrategy(max_iterations=3)
        # Low quality and iterations remaining: keep refining.
        assert planner.should_continue(0.5) is True
        # Quality already good enough: stop.
        assert planner.should_continue(0.8) is False
        # Iteration budget exhausted: stop even when quality is poor.
        planner._current_iteration = 3
        assert planner.should_continue(0.4) is False

    def test_iteration_tracking(self):
        """The iteration counter should increment and reset correctly."""
        planner = RefinementStrategy()
        assert planner.get_iteration() == 0
        planner.increment_iteration()
        planner.increment_iteration()
        assert planner.get_iteration() == 2
        planner.reset()
        assert planner.get_iteration() == 0

    def test_create_refined_query(self):
        """A search action's search_query parameter should replace the original query."""
        planner = RefinementStrategy()
        search_step = RefinementAction(
            action_type="search",
            description="Search for more",
            parameters={"search_query": "specific search query"},
            priority=1,
        )
        rewritten = planner.create_refined_query("original query", search_step)
        assert rewritten == "specific search query"

    def test_merge_answers(self):
        """Merging should retain the original text and weave in the new material."""
        planner = RefinementStrategy()
        expand_step = RefinementAction(
            action_type="expand",
            description="Expand answer",
            parameters={},
            priority=1,
        )
        combined = planner.merge_answers(
            "Original answer.",
            "Additional information.",
            expand_step,
        )
        assert "Original answer" in combined
        assert "Additional" in combined