# prediqai/tests/api/test_api_train_predict_validate.py
# ganesh-vilje's picture
# Deploy to Hugging Face Main
# f8f02c0
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.api.v1.train import train_router
from app.api.v1.validation import validation_router
from app.api.v1.prediction import prediction_router
from app.api.dependencies import get_current_username
# Standalone FastAPI application for these tests: it mounts the train,
# validation and prediction routers, each under the "/v1" prefix.
app = FastAPI()
for _router in (train_router, validation_router, prediction_router):
    app.include_router(_router, prefix="/v1")
@pytest.fixture
def ml_client():
    """Yield a ``TestClient`` for ``app`` with the auth dependencies stubbed.

    ``acccess_token_bearer`` is overridden to return
    ``{"user_id": "test_user"}`` and ``get_current_username`` to return
    ``"test_user"``. Every override is removed again during teardown, even
    if the client fails while starting up or shutting down, so the stubs
    cannot leak into later tests through the shared module-level ``app``.
    """
    # Imported inside the fixture, as before; the spelling matches the name
    # exported by app.api.v1.train.
    from app.api.v1.train import acccess_token_bearer

    # Override authentication and authorization
    app.dependency_overrides[acccess_token_bearer] = lambda: {"user_id": "test_user"}
    app.dependency_overrides[get_current_username] = lambda: "test_user"
    try:
        with TestClient(app) as client:
            yield client
    finally:
        # Always restore the app's real dependencies.
        app.dependency_overrides.clear()
# --- Mocking Services ---
@pytest.fixture
def mock_train_service():
    """Patch ``training_service`` in ``app.api.v1.train`` and yield the mock."""
    target = "app.api.v1.train.training_service"
    with patch(target) as patched_service:
        yield patched_service
@pytest.fixture
def mock_validation_service():
    """Patch ``validation_service`` in ``app.api.v1.validation`` and yield the mock."""
    target = "app.api.v1.validation.validation_service"
    with patch(target) as patched_service:
        yield patched_service
@pytest.fixture
def mock_prediction_service():
    """Patch ``prediction_service`` in ``app.api.v1.prediction`` and yield the mock."""
    target = "app.api.v1.prediction.prediction_service"
    with patch(target) as patched_service:
        yield patched_service
def test_start_training(ml_client, mock_train_service):
    """POST /v1/train with a ``preprocess_id`` (no file) calls ``train()`` once.

    The training service is mocked; the test checks that the response
    carries the mocked ``training_id`` and ``status``.
    """
    # Setup mock return
    mock_train_result = MagicMock()
    mock_train_result.model_dump.return_value = {
        "message": "started",
        "training_id": "train_123",
        "status": "running"
    }
    mock_train_service.train = AsyncMock(return_value=mock_train_result)

    payload = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "config": "{}",
        "preprocess_id": "prep_1",
        "version": "latest"
    }
    response = ml_client.post("/v1/train", data=payload)

    # Put the raw body in the failure message instead of printing
    # response.json(): decoding a non-JSON error body would raise and hide
    # the actual status-code failure.
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_123"
    assert data["status"] == "running"
    mock_train_service.train.assert_called_once()
def test_start_training_pipeline_mode(ml_client, mock_train_service):
    """Pipeline mode: file + columns + target_columns triggers train_pipeline().

    The training service is mocked; the test checks that the response
    carries the mocked ``training_id`` and ``status``.
    """
    mock_train_result = MagicMock()
    mock_train_result.model_dump.return_value = {
        "message": "started",
        "training_id": "train_pipeline_123",
        "status": "running"
    }
    mock_train_service.train_pipeline = AsyncMock(return_value=mock_train_result)

    payload = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "config": "{}",
        "columns": '{"age":"numerical","gender":"categorical"}',
        "target_columns": '["income"]',
        "language": "en",
    }
    files = {"file": ("dataset.csv", b"age,gender,income\n25,M,high\n30,F,low", "text/csv")}
    response = ml_client.post("/v1/train", data=payload, files=files)

    # Put the raw body in the failure message instead of printing
    # response.json(): decoding a non-JSON error body would raise and hide
    # the actual status-code failure.
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_pipeline_123"
    assert data["status"] == "running"
    mock_train_service.train_pipeline.assert_called_once()
def test_start_training_pipeline_mode_missing_columns(ml_client, mock_train_service):
    """A file upload with target_columns but no columns is rejected with 400."""
    upload = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}
    form_fields = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "target_columns": '["income"]',
    }

    result = ml_client.post("/v1/train", data=form_fields, files=upload)

    assert result.status_code == 400
    error_detail = result.json()["detail"]
    # The error detail is expected to mention the missing field.
    assert "columns" in error_detail
def test_start_training_pipeline_mode_missing_target_columns(ml_client, mock_train_service):
    """A file upload with columns but no target_columns is rejected with 400."""
    upload = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}
    form_fields = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "columns": '{"age":"numerical"}',
    }

    result = ml_client.post("/v1/train", data=form_fields, files=upload)

    assert result.status_code == 400
    error_detail = result.json()["detail"]
    # The error detail is expected to mention the missing field.
    assert "target_columns" in error_detail
def test_start_training_no_file_no_preprocess_id(ml_client, mock_train_service):
    """Submitting neither a file nor a preprocess_id must yield HTTP 400."""
    form_fields = {
        "compliance_type": "firco",
        "model_name": "infinity",
    }

    result = ml_client.post("/v1/train", data=form_fields)

    assert result.status_code == 400
    error_detail = result.json()["detail"]
    # Either missing input may be named in the error detail.
    assert any(keyword in error_detail for keyword in ("preprocess_id", "file"))
def test_list_all_training_runs(ml_client, mock_train_service):
    """GET /v1/history/training-runs/ returns 200 with data and pagination keys."""
    mock_train_service.count_by_details = AsyncMock(return_value=1)
    mock_train_service.list_by_user = AsyncMock(return_value=[])

    formatter_patch = patch(
        "app.api.v1.train.format_universal_runs_response",
        new_callable=AsyncMock,
    )
    # The formatter must stay patched while the request is being served.
    with formatter_patch as fake_formatter:
        fake_formatter.return_value = {"training_runs": []}
        reply = ml_client.get("/v1/history/training-runs/")

    assert reply.status_code == 200
    body = reply.json()
    assert "data" in body
    assert "pagination" in body
def test_validate_model_api(ml_client, mock_validation_service):
    """POST /v1/validate with a CSV upload echoes the mocked id and accuracy."""
    dumped_result = {
        "message": "started",
        "validation_id": "val_123",
        "metrics": {"accuracy": 0.95},
        "training_id": "latest"
    }
    fake_result = MagicMock()
    fake_result.model_dump.return_value = dumped_result
    mock_validation_service.validate = AsyncMock(return_value=fake_result)

    form_fields = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "version": "latest",
        "number_of_reasonings": -1
    }
    upload = {"file": ("test.csv", b"dummy,data", "text/csv")}

    reply = ml_client.post("/v1/validate", data=form_fields, files=upload)

    assert reply.status_code == 200
    body = reply.json()
    assert body["validation_id"] == "val_123"
    assert body["metrics"]["accuracy"] == 0.95
    mock_validation_service.validate.assert_called_once()
def test_predict_api(ml_client, mock_prediction_service):
    """POST /v1/predict with a CSV upload echoes the mocked prediction id."""
    dumped_result = {
        "message": "started",
        "prediction_id": "pred_123",
        "predictions": []
    }
    fake_result = MagicMock()
    fake_result.model_dump.return_value = dumped_result
    mock_prediction_service.predict = AsyncMock(return_value=fake_result)

    form_fields = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "source_type": "file",
        "version": "latest",
        "number_of_reasonings": -1
    }
    upload = {"file": ("test.csv", b"dummy,data", "text/csv")}

    reply = ml_client.post("/v1/predict", data=form_fields, files=upload)

    assert reply.status_code == 200
    body = reply.json()
    assert body["prediction_id"] == "pred_123"
    mock_prediction_service.predict.assert_called_once()