# Source: SPG_ML/tests/test_api.py
# Author: meetmendapara
# Commit: "Initial commit for ML space" (df31aa1)
"""
Tests for Cognexa ML Service
Run with: pytest tests/test_api.py -v
"""
import pytest
from fastapi.testclient import TestClient
import sys
import os
# Put the parent of this file's directory at the front of sys.path so the
# `main` module imported below can be resolved.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from main import app
# Single module-level client shared by every test class in this file.
# NOTE(review): the client is not entered as a context manager; confirm the
# app does not rely on startup/lifespan handlers being run for these tests.
client = TestClient(app)
# ============================================================================
# Health Endpoint Tests
# ============================================================================
class TestHealthEndpoints:
    """Smoke tests for the root and health-check routes."""

    def test_root_endpoint(self):
        """GET / answers 200 with service name, version, status and endpoints."""
        resp = client.get("/")
        assert resp.status_code == 200

        info = resp.json()
        assert info["service"] == "Cognexa ML Service"
        assert "version" in info
        assert info["status"] == "running"
        assert "endpoints" in info

    def test_health_endpoint(self):
        """GET /health answers 200 with status, timestamp and service id."""
        resp = client.get("/health")
        assert resp.status_code == 200

        health = resp.json()
        assert health["status"] == "healthy"
        assert "timestamp" in health
        assert health["service"] == "cognexa-ml"
# ============================================================================
# Personality Analysis Tests
# ============================================================================
class TestPersonalityAnalysis:
    """Tests that POST trait scores to /api/personality/analyze."""

    @staticmethod
    def _strengths_text(body):
        """Return ``body["strengths"]`` as a single lower-cased string.

        A list value is joined with single spaces first; any other value is
        lower-cased directly.
        """
        strengths = body["strengths"]
        if isinstance(strengths, list):
            strengths = " ".join(strengths)
        return strengths.lower()

    def test_analyze_personality_balanced(self):
        """All traits at 50 yield a 200 response carrying every report field."""
        payload = {
            "openness": 50,
            "conscientiousness": 50,
            "extraversion": 50,
            "agreeableness": 50,
            "neuroticism": 50,
        }
        resp = client.post("/api/personality/analyze", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        expected_fields = (
            "personality_type",
            "strengths",
            "weaknesses",
            "work_style",
            "recommendations",
            "cognitive_style",
            "team_role",
        )
        for field in expected_fields:
            assert field in body

    def test_analyze_personality_high_conscientiousness(self):
        """Conscientiousness of 85 should surface organization-related strengths."""
        payload = {
            "openness": 60,
            "conscientiousness": 85,
            "extraversion": 40,
            "agreeableness": 55,
            "neuroticism": 30,
        }
        resp = client.post("/api/personality/analyze", json=payload)
        assert resp.status_code == 200

        text = self._strengths_text(resp.json())
        assert "organization" in text or "reliable" in text

    def test_analyze_personality_high_extraversion(self):
        """Extraversion of 85 should surface communication-related strengths."""
        payload = {
            "openness": 65,
            "conscientiousness": 50,
            "extraversion": 85,
            "agreeableness": 70,
            "neuroticism": 35,
        }
        resp = client.post("/api/personality/analyze", json=payload)
        assert resp.status_code == 200

        text = self._strengths_text(resp.json())
        assert "communication" in text or "collaborative" in text

    def test_analyze_personality_validation(self):
        """An out-of-range trait value (openness=150) is rejected with 422."""
        payload = {
            "openness": 150,  # Invalid
            "conscientiousness": 50,
            "extraversion": 50,
            "agreeableness": 50,
            "neuroticism": 50,
        }
        resp = client.post("/api/personality/analyze", json=payload)
        assert resp.status_code == 422  # Validation error

    def test_analyze_personality_traits_analysis(self):
        """The response includes a per-trait breakdown; openness=75 is 'high'."""
        payload = {
            "openness": 75,
            "conscientiousness": 60,
            "extraversion": 45,
            "agreeableness": 65,
            "neuroticism": 40,
        }
        resp = client.post("/api/personality/analyze", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "traits_analysis" in body

        per_trait = body["traits_analysis"]
        assert "openness" in per_trait
        assert per_trait["openness"]["level"] == "high"
        assert "description" in per_trait["openness"]
# ============================================================================
# Task Prediction Tests
# ============================================================================
class TestTaskPrediction:
    """Tests that POST a single task to /api/tasks/predict."""

    @staticmethod
    def _future_due_date(days=7):
        """Return a UTC timestamp ``days`` from now as ``YYYY-MM-DDTHH:MM:SSZ``.

        Used instead of a hard-coded calendar date so the request always
        carries an upcoming deadline, no matter when the suite is run.
        """
        from datetime import datetime, timedelta, timezone

        due = datetime.now(timezone.utc) + timedelta(days=days)
        return due.strftime("%Y-%m-%dT%H:%M:%SZ")

    def test_predict_simple_task(self):
        """A low-complexity task returns every prediction field within range."""
        payload = {
            "title": "Review document",
            "description": "Review the quarterly report",
            "category": "WORK",
            "priority": "MEDIUM",
            "estimated_duration": 30,
            "complexity": 2,
        }
        resp = client.post("/api/tasks/predict", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "completion_probability" in body
        assert 0 <= body["completion_probability"] <= 1
        assert "difficulty_level" in body
        assert body["difficulty_level"] in ["EASY", "MODERATE", "HARD"]
        assert "stress_level" in body
        assert 1 <= body["stress_level"] <= 10
        assert "predicted_duration" in body
        assert "recommendations" in body

    def test_predict_complex_task(self):
        """A long, complexity-5 task with a deadline is at least MODERATE."""
        payload = {
            "title": "Complete research paper",
            "description": "Write and submit the final research paper",
            "category": "ACADEMIC",
            "priority": "HIGH",
            "estimated_duration": 480,
            "complexity": 5,
            # Relative deadline: a fixed date would eventually lie in the past.
            "due_date": self._future_due_date(),
        }
        resp = client.post("/api/tasks/predict", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        # Complex task should have higher difficulty
        assert body["difficulty_level"] in ["MODERATE", "HARD"]
        # Should have risk factors for complex task
        assert "risk_factors" in body

    def test_predict_with_personality(self):
        """Supplying a personality profile keeps completion probability >= 0.5."""
        payload = {
            "title": "Team presentation",
            "category": "WORK",
            "priority": "HIGH",
            "estimated_duration": 60,
            "complexity": 3,
            "personality": {
                "openness": 70,
                "conscientiousness": 80,
                "extraversion": 65,
                "agreeableness": 60,
                "neuroticism": 30,
            },
        }
        resp = client.post("/api/tasks/predict", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        # High conscientiousness should improve probability
        assert body["completion_probability"] >= 0.5

    def test_predict_urgent_task(self):
        """An URGENT-priority task is expected to report stress of 5 or more."""
        payload = {
            "title": "Emergency bug fix",
            "category": "WORK",
            "priority": "URGENT",
            "estimated_duration": 120,
            "complexity": 4,
        }
        resp = client.post("/api/tasks/predict", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        # Urgent task should have higher stress
        assert body["stress_level"] >= 5
# ============================================================================
# Batch Prediction Tests
# ============================================================================
class TestBatchPrediction:
    """Tests that POST several tasks at once to /api/tasks/batch-predict."""

    def test_batch_predict_multiple_tasks(self):
        """Three tasks plus a personality yield three predictions and a summary."""
        tasks = [
            {
                "title": "Task 1",
                "category": "WORK",
                "priority": "LOW",
                "complexity": 2,
            },
            {
                "title": "Task 2",
                "category": "ACADEMIC",
                "priority": "HIGH",
                "complexity": 4,
            },
            {
                "title": "Task 3",
                "category": "PERSONAL",
                "priority": "MEDIUM",
                "complexity": 1,
            },
        ]
        personality = {
            "openness": 60,
            "conscientiousness": 70,
            "extraversion": 50,
            "agreeableness": 65,
            "neuroticism": 40,
        }
        resp = client.post(
            "/api/tasks/batch-predict",
            json={"tasks": tasks, "personality": personality},
        )
        assert resp.status_code == 200

        body = resp.json()
        assert "predictions" in body
        assert len(body["predictions"]) == 3
        assert "summary" in body

        summary = body["summary"]
        assert summary["total_tasks"] == 3
        assert "average_completion_probability" in summary
        assert "workload_assessment" in summary

    def test_batch_predict_empty_list(self):
        """An empty task list is accepted and summarised as zero tasks."""
        resp = client.post("/api/tasks/batch-predict", json={"tasks": []})
        assert resp.status_code == 200

        body = resp.json()
        assert body["summary"]["total_tasks"] == 0
# ============================================================================
# Explanation Tests
# ============================================================================
class TestExplanation:
    """Tests that POST a task to /api/predict/explain."""

    def test_explain_prediction(self):
        """A 'completion' explanation returns every documented response field."""
        payload = {
            "task": {
                "title": "Complete project milestone",
                "category": "WORK",
                "priority": "HIGH",
                "estimated_duration": 120,
                "complexity": 4,
            },
            "prediction_type": "completion",
        }
        resp = client.post("/api/predict/explain", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        expected_fields = (
            "prediction",
            "base_value",
            "shap_values",
            "feature_ranking",
            "explanation",
            "waterfall_visualization",
        )
        for field in expected_fields:
            assert field in body

    def test_explain_with_personality(self):
        """A task carrying a nested personality yields a non-empty ranking."""
        task = {
            "title": "Study for exam",
            "category": "ACADEMIC",
            "priority": "HIGH",
            "complexity": 4,
            "personality": {
                "openness": 75,
                "conscientiousness": 85,
                "extraversion": 40,
                "agreeableness": 60,
                "neuroticism": 55,
            },
        }
        resp = client.post("/api/predict/explain", json={"task": task})
        assert resp.status_code == 200

        body = resp.json()
        # Should have feature ranking
        assert len(body["feature_ranking"]) > 0
# ============================================================================
# Intervention Tests
# ============================================================================
class TestInterventions:
    """Tests that POST prediction results to /api/interventions/suggest."""

    def test_suggest_interventions_high_stress(self):
        """Stress of 8.5 should produce at least one STRESS-type intervention."""
        payload = {
            "task_id": "test-task-123",
            "completion_probability": 0.4,
            "stress_level": 8.5,
            "current_workload": 7,
        }
        resp = client.post("/api/interventions/suggest", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "interventions" in body
        assert len(body["interventions"]) > 0
        assert "overall_strategy" in body

        # Should include stress management interventions
        kinds = [item["type"] for item in body["interventions"]]
        assert any("STRESS" in kind for kind in kinds)

    def test_suggest_interventions_low_probability(self):
        """Probability of 0.3 should produce a TASK- or DEADLINE-type intervention."""
        payload = {
            "task_id": "test-task-456",
            "completion_probability": 0.3,
            "stress_level": 5.0,
        }
        resp = client.post("/api/interventions/suggest", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        # Should include task restructuring interventions
        kinds = [item["type"] for item in body["interventions"]]
        assert any("TASK" in kind or "DEADLINE" in kind for kind in kinds)
# ============================================================================
# Productivity Forecast Tests
# ============================================================================
class TestProductivityForecast:
    """Tests that POST usage history to /api/productivity/forecast."""

    def test_forecast_productivity(self):
        """Five days of history and forecast_days=7 yield seven daily entries."""
        weekday_scores = (
            ("Monday", 70),
            ("Tuesday", 85),
            ("Wednesday", 90),
            ("Thursday", 80),
            ("Friday", 65),
        )
        payload = {
            "user_id": "test-user-123",
            "historical_data": [
                {"day_of_week": day, "productivity": score}
                for day, score in weekday_scores
            ],
            "forecast_days": 7,
        }
        resp = client.post("/api/productivity/forecast", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "daily_forecasts" in body
        assert len(body["daily_forecasts"]) == 7
        assert "patterns_identified" in body
        assert "optimization_suggestions" in body

        # Check forecast structure
        entry_fields = (
            "date",
            "day",
            "predicted_productivity",
            "recommended_task_count",
        )
        for entry in body["daily_forecasts"]:
            for field in entry_fields:
                assert field in entry

    def test_forecast_with_personality(self):
        """No history plus a personality and forecast_days=5 yields five entries."""
        payload = {
            "user_id": "test-user-456",
            "historical_data": [],
            "forecast_days": 5,
            "personality": {
                "openness": 60,
                "conscientiousness": 80,
                "extraversion": 50,
                "agreeableness": 65,
                "neuroticism": 35,
            },
        }
        resp = client.post("/api/productivity/forecast", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert len(body["daily_forecasts"]) == 5
# ============================================================================
# Model Info Tests
# ============================================================================
class TestModelInfo:
    """Tests for GET /api/model/info."""

    def test_get_model_info(self):
        """The endpoint lists both models plus update time and accuracy metrics."""
        resp = client.get("/api/model/info")
        assert resp.status_code == 200

        body = resp.json()
        assert "models" in body

        models = body["models"]
        assert "personality_analysis" in models
        assert "task_prediction" in models

        assert "last_updated" in body
        assert "accuracy_metrics" in body
# ============================================================================
# Integration Tests
# ============================================================================
class TestIntegration:
    """End-to-end checks that chain several endpoints together."""

    def test_personality_to_task_flow(self):
        """Analyze a personality, predict a task with it, then ask for help.

        The intervention request is only sent when the predicted completion
        probability is below 0.7.
        """
        traits = {
            "openness": 70,
            "conscientiousness": 75,
            "extraversion": 55,
            "agreeableness": 60,
            "neuroticism": 40,
        }

        # Step 1: personality analysis
        analysis_resp = client.post("/api/personality/analyze", json=traits)
        assert analysis_resp.status_code == 200

        # Step 2: task prediction using the same trait scores
        task_payload = {
            "title": "Complete assignment",
            "category": "ACADEMIC",
            "priority": "HIGH",
            "complexity": 3,
            "personality": traits,
        }
        predict_resp = client.post("/api/tasks/predict", json=task_payload)
        assert predict_resp.status_code == 200

        # Step 3: interventions, only when the outlook is poor
        outcome = predict_resp.json()
        if outcome["completion_probability"] < 0.7:
            suggest_payload = {
                "task_id": "test-task",
                "completion_probability": outcome["completion_probability"],
                "stress_level": outcome["stress_level"],
            }
            suggest_resp = client.post(
                "/api/interventions/suggest", json=suggest_payload
            )
            assert suggest_resp.status_code == 200
# Allow running this file directly as a script. pytest.main() returns the
# run's exit code; raising SystemExit with it makes the process exit non-zero
# when any test fails instead of always reporting success.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__, "-v"]))