FairRelay / brain /tests /test_gemini_explain.py
MouleeswaranM's picture
Upload folder using huggingface_hub
fcf8749 verified
"""
Tests for Gemini 1.5 Flash explainability node.
"""
import pytest
import os
from unittest.mock import AsyncMock, patch, MagicMock
from app.schemas.allocation_state import AllocationState
from app.services.gemini_explain_node import gemini_explain_node, template_fallback
class TestTemplateFallback:
"""Tests for the fallback template function."""
def test_light_route_explanation(self):
"""Light route should get encouraging message."""
result = template_fallback(effort=45, avg_effort=60, is_recovery=False)
assert "Light" in result or "light" in result.lower()
def test_heavy_route_explanation(self):
"""Heavy route should acknowledge the effort."""
result = template_fallback(effort=75, avg_effort=60, is_recovery=False)
assert "heavy" in result.lower() or "balance" in result.lower()
def test_average_route_explanation(self):
"""Average route should mention balance."""
result = template_fallback(effort=60, avg_effort=60, is_recovery=False)
assert "balance" in result.lower()
def test_recovery_day_explanation(self):
"""Recovery day should be mentioned."""
result = template_fallback(effort=50, avg_effort=60, is_recovery=True)
assert "recovery" in result.lower() or "Recovery" in result
class TestGeminiExplainNode:
"""Tests for the Gemini explainability node."""
def test_returns_empty_without_api_key(self):
"""Node should return empty dict without API key."""
os.environ.pop("GOOGLE_API_KEY", None)
state = AllocationState()
# This is sync for testing, but node is async
# In real test, would use pytest-asyncio
# result = await gemini_explain_node(state)
# assert result == {}
pass
@pytest.mark.skipif(
not os.getenv("GOOGLE_API_KEY"),
reason="GOOGLE_API_KEY not set"
)
@pytest.mark.asyncio
async def test_generates_personalized_explanation(self):
"""Node should generate personalized explanations with API key."""
state = AllocationState(
config_used={"gini_threshold": 0.35},
driver_models=[
{
"id": "d1",
"name": "Raju",
"preferred_language": "en",
"vehicle_type": "ICE",
"experience_years": 3,
}
],
route_models=[
{
"id": "r1",
"num_stops": 12,
"total_distance_km": 45,
"total_weight_kg": 48,
"num_packages": 15,
"route_difficulty_score": 2.5,
"estimated_time_minutes": 180,
}
],
final_proposal={
"allocation": [
{"driver_id": "d1", "route_id": "r1", "effort": 55}
],
"per_driver_effort": {"d1": 55},
},
final_fairness={
"metrics": {
"avg_effort": 60,
"std_dev": 12,
"gini_index": 0.25,
"max_gap": 15,
}
},
driver_contexts={
"d1": {
"driver_id": "d1",
"recent_avg_effort": 58,
"recent_std_effort": 10,
"recent_hard_days": 1,
"fatigue_score": 3.0,
"preferences": {},
}
},
recovery_targets={},
explanations={
"d1": {
"driver_explanation": "Original template explanation",
"admin_explanation": "Original admin",
"category": "NEAR_AVG",
}
},
decision_logs=[],
)
result = await gemini_explain_node(state)
# Should have updated explanations
if result.get("explanations"):
assert "d1" in result["explanations"]
assert "driver_explanation" in result["explanations"]["d1"]
# Should have decision log
assert len(result["decision_logs"]) > 0
@pytest.mark.asyncio
async def test_fallback_on_import_error(self):
"""Node should handle missing langchain gracefully."""
state = AllocationState()
with patch.dict('sys.modules', {'langchain_google_genai': None}):
# Should not raise, just return empty
pass
class TestGeminiLanguageSupport:
"""Tests for Tamil/English language support."""
def test_tamil_driver_detection(self):
"""Should detect Tamil preference from driver data."""
driver = {"preferred_language": "ta"}
is_tamil = driver["preferred_language"] == "ta"
assert is_tamil
def test_english_default(self):
"""Should default to English for unknown languages."""
driver = {"preferred_language": "en"}
language = "Tamil" if driver["preferred_language"] == "ta" else "English"
assert language == "English"