File size: 5,338 Bytes
fcf8749 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 | """
Tests for Gemini 1.5 Flash explainability node.
"""
import pytest
import os
from unittest.mock import AsyncMock, patch, MagicMock
from app.schemas.allocation_state import AllocationState
from app.services.gemini_explain_node import gemini_explain_node, template_fallback
class TestTemplateFallback:
    """Tests for the template-based fallback explanation function."""

    def test_light_route_explanation(self):
        """A route below the average effort should get an encouraging 'light' message."""
        result = template_fallback(effort=45, avg_effort=60, is_recovery=False)
        # One case-insensitive check covers both "Light" and "light".
        assert "light" in result.lower()

    def test_heavy_route_explanation(self):
        """A route above the average effort should acknowledge the effort."""
        result = template_fallback(effort=75, avg_effort=60, is_recovery=False)
        text = result.lower()
        assert "heavy" in text or "balance" in text

    def test_average_route_explanation(self):
        """A route at the average effort should mention balance."""
        result = template_fallback(effort=60, avg_effort=60, is_recovery=False)
        assert "balance" in result.lower()

    def test_recovery_day_explanation(self):
        """A recovery day should be mentioned as such."""
        result = template_fallback(effort=50, avg_effort=60, is_recovery=True)
        # One case-insensitive check covers both "Recovery" and "recovery".
        assert "recovery" in result.lower()
class TestGeminiExplainNode:
    """Tests for the Gemini explainability node."""

    @pytest.mark.asyncio
    async def test_returns_empty_without_api_key(self, monkeypatch):
        """Node should return an empty dict when GOOGLE_API_KEY is unset."""
        # monkeypatch restores the variable after the test. Popping it from
        # os.environ directly would leak into every later test in the session.
        monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
        state = AllocationState()
        result = await gemini_explain_node(state)
        assert result == {}

    @pytest.mark.skipif(
        not os.getenv("GOOGLE_API_KEY"),
        reason="GOOGLE_API_KEY not set",
    )
    @pytest.mark.asyncio
    async def test_generates_personalized_explanation(self):
        """Node should generate personalized explanations with API key."""
        state = AllocationState(
            config_used={"gini_threshold": 0.35},
            driver_models=[
                {
                    "id": "d1",
                    "name": "Raju",
                    "preferred_language": "en",
                    "vehicle_type": "ICE",
                    "experience_years": 3,
                }
            ],
            route_models=[
                {
                    "id": "r1",
                    "num_stops": 12,
                    "total_distance_km": 45,
                    "total_weight_kg": 48,
                    "num_packages": 15,
                    "route_difficulty_score": 2.5,
                    "estimated_time_minutes": 180,
                }
            ],
            final_proposal={
                "allocation": [
                    {"driver_id": "d1", "route_id": "r1", "effort": 55}
                ],
                "per_driver_effort": {"d1": 55},
            },
            final_fairness={
                "metrics": {
                    "avg_effort": 60,
                    "std_dev": 12,
                    "gini_index": 0.25,
                    "max_gap": 15,
                }
            },
            driver_contexts={
                "d1": {
                    "driver_id": "d1",
                    "recent_avg_effort": 58,
                    "recent_std_effort": 10,
                    "recent_hard_days": 1,
                    "fatigue_score": 3.0,
                    "preferences": {},
                }
            },
            recovery_targets={},
            explanations={
                "d1": {
                    "driver_explanation": "Original template explanation",
                    "admin_explanation": "Original admin",
                    "category": "NEAR_AVG",
                }
            },
            decision_logs=[],
        )
        result = await gemini_explain_node(state)
        # Explanations are only checked when the node returned them.
        # NOTE(review): presumably the node may omit them on an LLM failure;
        # confirm against the node.
        if result.get("explanations"):
            assert "d1" in result["explanations"]
            assert "driver_explanation" in result["explanations"]["d1"]
        # Should have a decision log. Use .get so a missing key reports a
        # clear assertion failure instead of a KeyError.
        assert result.get("decision_logs"), "node should emit at least one decision log"

    @pytest.mark.asyncio
    async def test_fallback_on_import_error(self, monkeypatch):
        """Node should handle missing langchain gracefully and return an empty dict."""
        # A dummy key keeps the node from short-circuiting on a missing key
        # (see test_returns_empty_without_api_key), so the import path is reached.
        monkeypatch.setenv("GOOGLE_API_KEY", "dummy-key-for-test")
        state = AllocationState()
        # Mapping the module to None makes `import langchain_google_genai`
        # raise ImportError.
        # NOTE(review): this assumes the node imports it lazily at call time;
        # confirm.
        with patch.dict("sys.modules", {"langchain_google_genai": None}):
            result = await gemini_explain_node(state)
        # Should not raise, just return empty
        assert result == {}
class TestGeminiLanguageSupport:
    """Tests for Tamil/English language support."""

    # NOTE(review): these tests exercise a local copy of the language-selection
    # rule rather than production code. They cannot catch regressions in
    # gemini_explain_node; point them at the node's helper once one is importable.

    @staticmethod
    def _prompt_language(driver):
        """Return "Tamil" when driver's preferred_language is "ta", otherwise "English"."""
        return "Tamil" if driver.get("preferred_language") == "ta" else "English"

    def test_tamil_driver_detection(self):
        """Should detect Tamil preference from driver data."""
        driver = {"preferred_language": "ta"}
        assert self._prompt_language(driver) == "Tamil"

    def test_english_default(self):
        """Should use English for "en" and default to English for unknown or missing languages."""
        # "en" is the explicit case. The remaining entries are unknown, empty
        # and missing codes, which must all fall back to English.
        for driver in (
            {"preferred_language": "en"},
            {"preferred_language": "hi"},
            {"preferred_language": ""},
            {},
        ):
            assert self._prompt_language(driver) == "English", driver
|