|
|
"""Unit tests for HFInferenceJudgeHandler.""" |
|
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch |
|
|
|
|
|
import pytest |
|
|
|
|
|
from src.agent_factory.judges import HFInferenceJudgeHandler |
|
|
from src.utils.models import Citation, Evidence |
|
|
|
|
|
|
|
|
@pytest.mark.unit
class TestHFInferenceJudgeHandler:
    """Unit tests covering HFInferenceJudgeHandler's assess flow and JSON parsing."""

    @pytest.fixture
    def mock_client(self):
        """Yield the object returned by the patched ``InferenceClient`` constructor.

        The patch targets the name as looked up in ``src.agent_factory.judges``,
        so any handler built while this fixture is active receives the mock.
        """
        with patch("src.agent_factory.judges.InferenceClient") as client_cls:
            fake_client = MagicMock()
            client_cls.return_value = fake_client
            yield fake_client

    @pytest.fixture
    def handler(self, mock_client):
        """Build a handler while ``InferenceClient`` is patched via ``mock_client``."""
        return HFInferenceJudgeHandler()

    @pytest.mark.asyncio
    async def test_assess_success(self, handler, mock_client):
        """A fenced-JSON reply from the model is parsed into a positive assessment.

        The event loop's ``run_in_executor`` is replaced with an ``AsyncMock``
        that returns a chat-completion-shaped object whose first choice carries
        the JSON payload wrapped in a markdown code fence.
        """
        import json

        # Key order matters only for the serialized text; it mirrors the
        # expected assessment schema.
        details = {
            "mechanism_score": 8,
            "mechanism_reasoning": "Good mechanism",
            "clinical_evidence_score": 7,
            "clinical_reasoning": "Good clinical",
            "drug_candidates": ["Drug A"],
            "key_findings": ["Finding 1"],
        }
        payload = {
            "details": details,
            "sufficient": True,
            "confidence": 0.85,
            "recommendation": "synthesize",
            "next_search_queries": [],
            "reasoning": (
                "Sufficient evidence provided to support the hypothesis with high confidence."
            ),
        }

        reply_text = "Here is the analysis:\n```json\n" + json.dumps(payload) + "\n```"
        # Shape: response.choices[0].message.content -> reply_text
        response = MagicMock()
        response.choices = [MagicMock(message=MagicMock(content=reply_text))]

        evidence = [
            Evidence(
                content="test",
                citation=Citation(source="pubmed", title="t", url="u", date="d"),
            )
        ]

        with patch("asyncio.get_running_loop") as get_loop:
            get_loop.return_value.run_in_executor = AsyncMock(return_value=response)
            result = await handler.assess("test question", evidence)

        assert result.sufficient is True
        assert result.confidence == 0.85
        assert result.details.drug_candidates == ["Drug A"]

    @pytest.mark.asyncio
    async def test_assess_fallback_logic(self, handler, mock_client):
        """Each fallback attempt fails, so assess returns an insufficient verdict.

        ``_call_with_retry`` raises on every invocation. The test expects exactly
        three attempts (presumably one per model in the handler's fallback
        chain) and a result whose reasoning mentions a failure or error.
        """
        failures = [
            Exception("Model 1 failed"),
            Exception("Model 2 failed"),
            Exception("Model 3 failed"),
        ]

        with patch("asyncio.get_running_loop"), patch.object(
            handler, "_call_with_retry", side_effect=failures
        ) as retry_mock:
            result = await handler.assess("test", [])

        assert retry_mock.call_count == 3
        assert result.sufficient is False
        reasoning = result.reasoning.lower()
        assert "failed" in reasoning or "error" in reasoning

    def test_extract_json_robustness(self, handler):
        """``_extract_json`` copes with bare, fenced, chatty, nested and invalid input."""
        extract = handler._extract_json

        # Plain JSON object.
        assert extract('{"a": 1}') == {"a": 1}

        # Markdown code fence tagged as json.
        assert extract('```json\n{"a": 1}\n```') == {"a": 1}

        # Multi-line object surrounded by conversational text.
        chatty = (
            "\n"
            "Sure, here is the JSON:\n"
            "{\n"
            '    "a": 1,\n'
            '    "b": {\n'
            '        "c": 2\n'
            "    }\n"
            "}\n"
            "Hope that helps!\n"
        )
        assert extract(chatty) == {"a": 1, "b": {"c": 2}}

        # A closing brace inside a string value must not end the object early.
        assert extract('{"a": {"b": "}"}}') == {"a": {"b": "}"}}

        # Non-JSON and unterminated input both yield None.
        assert extract("Not JSON") is None
        assert extract("{Incomplete") is None
|
|
|